diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 02e2bcf4b9..f559928f8c 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -53,7 +53,7 @@ --set=full \ --kernel-name "GroupHadamardAmaxTmaKernel" \ -s 5 -c 5 \ - python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 """ @@ -173,7 +173,9 @@ def benchmark_linear( return timing_ms -def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None): +def run_benchmark_linear( + mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False +): data = [] assert not use_bias, "Bias is not supported for GroupedLinear benchmark" @@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None device = "cuda" x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] - assert m % num_gemms == 0 - m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits + m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided # Bias is not supported for GroupedLinear benchmark bias = None # Run the benchmark print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") print(f"m_splits: {m_splits}") + print(f"fwd_only: {fwd_only}") grouped_fwd_bwd_timing_ms = benchmark_linear( x, @@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None m_splits, bias, recipe_name, - mode="fwd_bwd", + mode="fwd_only" if fwd_only else "fwd_bwd", num_gemms=num_gemms, ) @@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ] ) + timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms" + df = 
pd.DataFrame( data=data, columns=[ @@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None "n", "recipe", "num_gemms", - "grouped_fwd_bwd_time_ms", + timing_notation, ], ) @@ -234,7 +238,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None parser = argparse.ArgumentParser() parser.add_argument("--profile", action="store_true", help="Enable profiling mode") parser.add_argument( - "--output_dir", + "--output-dir", type=str, default="benchmark_output/", help="output path for report", @@ -266,6 +270,12 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None default=2048, help="Output dimension to use, default is 2048", ) + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Run forward pass only, default is both forward and backward passes", + ) args = parser.parse_args() jagged_input_splits = None @@ -297,7 +307,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None if jagged_input_splits is not None: num_gemms_list = [len(jagged_input_splits)] - token_dim_list = [65536] + token_dim_list = [16384, 32768, 65536, 98304] hidden_dim_list = [7168] output_dim_list = [2048] @@ -371,7 +381,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None recipe_name, use_bias, num_gemms=num_gemms, - m_splits=jagged_input_splits, + m_splits_provided=jagged_input_splits, + fwd_only=args.fwd_only, ) df_linears = pd.concat([df_linears, df]) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 10aa3eb505..a29dcb4279 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference( for i in range(len(x_qx)): if split_sections[i] == 0: - # then just assert the same same and dtype because the buffer won't be 
zero out + # then just assert the same shape and dtype because the buffer won't be zeroed out assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) @@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference( # assert with zero tolerance for i in range(len(x_qx_t)): if split_sections[i] == 0: - # then just assert the same same and dtype because the buffer won't be zero out + # then just assert the same shape and dtype because the buffer won't be zeroed out assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) @@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference( (1024, 256), # larger sizes (8192, 1024), + (16384, 8192), (16384, 16384), ], ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78..79948e28f7 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu + hadamard_transform/group_hadamard_transform_cast_fusion.cu + hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu transpose/quantize_transpose_square_blockwise.cu diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 1ed46a3359..73467d7275 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, 
detail::get_compute_stream_event(s))); } } + +// Group quantize assumes contiguous inputs and outputs in memory allocation +// TODO (zhongbo): find a better way to make it a more generalized API +void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + dispatch::group_quantize_fwd_helper(input, outputs, split_sections, + num_tensors, quant_config, stream); +} diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9f7a4a9b01..6d4454402c 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -19,6 +19,7 @@ #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens } } +template +void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_tensors; + for (size_t i = 0; i < num_tensors; ++i) { + output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + } + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if 
(quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Take the scaling mode of the first output tensor + auto scaling_mode = output_tensors[0]->scaling_mode; + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + // Skip checking output tensor list + // output list here is allowed to have empty tensor + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + + // Launch NVFP4 group quantize kernel + nvfp4::group_quantize_transpose( + *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh new file mode 100644 index 0000000000..28b47e32d2 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -0,0 +1,904 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_quantize_transpose_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace group_quantize_transpose_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB, expand 64 if needed +struct MultiAmaxCastTransposeFusionArgs { + // Amax buffer for rowwise scaling + void *rowwise_amax_list[kMaxTensorsPerKernel]; + // Rowwise scale pointers with 128x4 padding included for rowwise scaling + void *output_rowwise_scale_inv_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) Amax buffer for colwise scaling + void *colwise_amax_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output data pointers for fp4 transposed output + void *output_colwise_data_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output scale inverse pointers for each tensor + void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output scale stride for colwise scaling + int output_colwise_scale_stride[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + // Number of tensors (splits) being processed by kernel + int num_tensors; +}; + +__device__ __forceinline__ int GetTensorId(MultiAmaxCastTransposeFusionArgs *kernel_args_ptr, + 
int offset) { + // check the kernel args and get the corresponding id + int tensor_id = 0; + while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) { + ++tensor_id; + } + return tensor_id; +} + +// Helper to get tensor id at offset, and also whether [offset_start, offset_end) crosses a split boundary. +__device__ __forceinline__ int GetTensorIdAndBoundary( + MultiAmaxCastTransposeFusionArgs *kernel_args_ptr, int offset_start, int offset_end, + bool *cross_boundary) { + int tensor_id_start = 0; + while (kernel_args_ptr->split_sections_range[tensor_id_start + 1] <= offset_start) { + ++tensor_id_start; + } + int tensor_id_end = tensor_id_start; + if (offset_end != offset_start) { + if (kernel_args_ptr->split_sections_range[tensor_id_start + 1] < offset_end) { + tensor_id_end = tensor_id_start + 1; + } + } + if (cross_boundary) { + *cross_boundary = (tensor_id_start != tensor_id_end); + } + return tensor_id_start; +} + +__device__ __forceinline__ void UpdateEncodeDecodeScaleFP32(float *amax_ptr, float *s_enc_ptr, + float *s_dec_ptr) { + float s_env_value = + (amax_ptr == nullptr) ? 1.0f : compute_global_encode_scaling_factor_FP4(*amax_ptr); + float s_dec_value = 1.0 / s_env_value; + *s_enc_ptr = s_env_value; + *s_dec_ptr = s_dec_value; + return; +} + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; + +// Each call generates 4x uint32_t random numbers +constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// Should this be SCALE_DIM or BLOCK_DIM? 
Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); 
+static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +template +__global__ void __launch_bounds__(THREADS_NUM) + group_quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + nvfp4_scale_t *const scales_ptr, const float *noop, + const size_t rows, const size_t cols, + const size_t scale_stride, const size_t *rng_state, + MultiAmaxCastTransposeFusionArgs kernel_args) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + // TODO(zhongbo): add back when transpose is supported + // const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + // const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool 
rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + + // TODO(zhongbo): add back when transpose is supported + // const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // TODO (zhongbo): finish this + float *amax_rowwise_ptr = nullptr; + float *amax_colwise_ptr = nullptr; + nvfp4_scale_t *split_rowwise_scale_ptr = nullptr; + + // suppose the amax is fixed for the current 128x128 tile (need 128 padding) + bool need_update_tensor_id = true; + int tensor_id = GetTensorIdAndBoundary(&kernel_args, block_offset_Y, block_offset_Y + CHUNK_DIM_Y, + &need_update_tensor_id); + size_t split_start = kernel_args.split_sections_range[tensor_id]; + size_t split_end = kernel_args.split_sections_range[tensor_id + 1]; + amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); + split_rowwise_scale_ptr = + reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); + + float S_enc_rowwise = 1.0f; + float S_dec_rowwise = 1.0f; + UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); + + // TODO (zhongbo): colwise scaling disabled for now because of transpose + float S_enc_colwise = 1.0f; + float S_dec_colwise = 1.0f; + if (amax_colwise_ptr != nullptr) { + UpdateEncodeDecodeScaleFP32(amax_colwise_ptr, &S_enc_colwise, &S_dec_colwise); + } 
else { + S_enc_colwise = S_enc_rowwise; + S_dec_colwise = S_dec_rowwise; + } + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + // for stages from 1 to STAGES - 1, we need to update the tensor id + // skip updating tensor id if it's the last CTA, and some stages will be out of bounds + if (need_update_tensor_id && stage > 0 && (block_offset_Y + stage_offset_Y < rows)) { + int new_tensor_id = GetTensorId(&kernel_args, block_offset_Y + stage_offset_Y); + if (new_tensor_id != tensor_id) { + tensor_id = new_tensor_id; + split_start = kernel_args.split_sections_range[tensor_id]; + split_end = kernel_args.split_sections_range[tensor_id + 1]; + amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); + UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); + split_rowwise_scale_ptr = + reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); + // TODO (zhongbo): colwise scaling disabled for now because of transpose + // Skip fetching colwise amax pointer and scaling factor updates + } + } + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) 
{ + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds 
= (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + 
const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE) < chunk_rows; + + // TODO(zhongbo): depending on input padding multiple (whether 128 or 64), use either scale_ptr or split_rowwise_scale_ptr + // const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + // scales_ptr[scale_idx_global] = S_dec_b_fp8; + // } + + // Map to local split coordinates + const size_t split_rows = split_end - split_start; + const size_t local_scale_row = scales_offset_Y - split_start; + + // Local bounds: 0 <= local_scale_row < split_rows + const bool local_rowwise_scale_is_within_bounds_Y = local_scale_row < split_rows; + + // Index inside this split’s scale buffer + const size_t scale_idx_local = local_scale_row * scale_stride + scales_offset_X; + + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y && + local_rowwise_scale_is_within_bounds_Y) { + split_rowwise_scale_ptr[scale_idx_local] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 
block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + // TODO(zhongbo): add back when transpose is supported + // const size_t global_offset_Y_t = block_offset_Y_t; + // const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + // TODO(zhongbo): add back when transpose is supported + // if constexpr (RETURN_TRANSPOSE) { + // ptx::cp_async_bulk_tensor_2d_shared_to_global( + // reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + // global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + // } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // TODO(zhongbo): add back when transpose is supported + // Vectorized store scaling factors through SHMEM + // if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + // using ScalesVec = Vec; + // const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + // ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + // const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + // const size_t count = // number of scales in Y dimension of this chunk + // (chunk_rows >= CHUNK_DIM_Y) ? 
SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
+  //   nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
+  //   constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
+  //   if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) {
+  //     // Fast path: vectorized store when destination is properly aligned
+  //     scales_vec.store_to(dst);
+  //   } else {
+  //     // Safe path: element-wise store for tails or unaligned destinations
+  //     scales_vec.store_to_elts(dst, 0, count);
+  //   }
+  // }
+
+  destroy_barriers(mbar, is_master_thread);
+#else
+  NVTE_DEVICE_ERROR("sm_100 or higher is required.");
+#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+}
+
+#endif  // FP4_TYPE_SUPPORTED
+}  // namespace group_quantize_transpose_kernel
+
+// Host-side launcher for the grouped NVFP4 quantization kernel.
+//
+// `input` is treated as one (rows x cols) matrix and handed to a single kernel launch.
+// `split_sections[i]` is the size of the i-th split and `output_list[i]` the tensor whose amax
+// and scale-inverse pointers are recorded for that split; zero-sized splits are skipped.
+//
+// Enforced by the checks below:
+//   - num_tensors == output_list.size() and num_tensors <= kMaxTensorsPerKernel
+//   - at least one output has data, and its data pointer is non-null
+//   - that first non-empty output has no column-wise data (transposed output is rejected)
+//   - rows and cols are both multiples of 32
+// Without FP4 support at compile time the function only raises NVTE_ERROR.
+//
+// NOTE(review): `noop` is dereferenced unconditionally (CheckNoopTensor(*noop, ...) and
+// noop->data.dptr), so callers presumably must pass a non-null pointer -- confirm.
+// NOTE(review): nothing here checks that the split sizes add up to `rows` -- confirm the caller
+// guarantees it.
+template
+void group_quantize_transpose(const Tensor &input, const Tensor *noop,
+                              std::vector &output_list, const size_t *split_sections,
+                              size_t num_tensors, const QuantizationConfig *quant_config,
+                              cudaStream_t stream) {
+#if FP4_TYPE_SUPPORTED
+  using namespace group_quantize_transpose_kernel;
+  using namespace ptx;
+  // A null quant_config means stochastic rounding is off.
+  bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
+
+  NVTE_CHECK(num_tensors == output_list.size(),
+             "Number of output tensors should match number of tensors.");
+  NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
+             "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
+
+  Tensor *output = nullptr;
+  // Find the first non-empty output tensor. It is the only one inspected for column-wise data
+  // and it supplies the output data descriptor, scale_stride and scales_ptr used below.
+  for (size_t i = 0; i < num_tensors; ++i) {
+    if (output_list[i]->has_data()) {
+      output = output_list[i];
+      break;
+    }
+  }
+  NVTE_CHECK(output != nullptr, "No output tensor found.");
+  // also check that the output does not have a null data pointer
+  NVTE_CHECK(output->data.dptr != nullptr, "Output data pointer is null.");
+
+  // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary
+  // to return the transposed data.
+  bool return_transpose = output->has_columnwise_data();
+  // forbid return transpose for now because group quantize transpose is not supported yet
+  NVTE_CHECK(!return_transpose, "Return transpose is not supported for group quantize transpose.");
+
+  // output_list is contiguous in memory, so take the first tensor as the contiguous output
+  // (this contiguity is the author's stated assumption; it is not verified here)
+  auto output_contiguous = output->data;
+
+  // No activation is fused into this path: OP is a null function pointer.
+  constexpr bool COMPUTE_ACTIVATIONS = false;
+  using ParamOP = Empty;
+  constexpr float (*OP)(float, const ParamOP &) = nullptr;
+
+  checkCuDriverContext(stream);
+  CheckNoopTensor(*noop, "cast_noop");
+  CheckInputTensor(input, "input");
+
+  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
+
+  const size_t rows = input.flat_first_dim();
+  const size_t cols = input.flat_last_dim();
+
+  NVTE_CHECK(rows % 32 == 0,
+             "Number of tensor rows must be a multiple of 32");  // 16B alignment for TMA
+  NVTE_CHECK(cols % 32 == 0,
+             "Number of tensor cols must be a multiple of 32");  // 16B alignment for TMA
+
+  // process the output list and produce the multi-tensor args for grouped kernel
+  // Entries are compacted: slot k of kernel_args describes the k-th split with a non-zero size.
+  MultiAmaxCastTransposeFusionArgs kernel_args;
+  kernel_args.num_tensors = 0;
+  kernel_args.split_sections_range[0] = 0;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    if (split_sections[i] == 0) {
+      continue;
+    }
+    kernel_args.rowwise_amax_list[kernel_args.num_tensors] =
+        reinterpret_cast(output_list[i]->amax.dptr);
+    kernel_args.output_rowwise_scale_inv_list[kernel_args.num_tensors] =
+        reinterpret_cast(output_list[i]->scale_inv.dptr);
+    // kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i];
+    // Running prefix sum with a leading zero: range[k + 1] = range[k] + size of split k.
+    kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
+        kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
+    // check overflow (the message implies the range entries are int32_t -- TODO confirm)
+    NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
+               "split_sections_range overflow the int32_t");
+    kernel_args.num_tensors++;
+  }
+
+  // One thread block per CHUNK_DIM_Y x CHUNK_DIM_X chunk of the input.
+  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
+  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
+  const dim3 grid(blocks_X, blocks_Y);
+  const size_t block_size = THREADS_NUM;
+
+  // Note (zhongbo): for group quantize of [x1, x2, ..., xn]
+  // for the rowwise scaling, scaling factor stride is shared between all tensors
+  // for the colwise scaling, scaling factor stride is different for each tensor because of transpose
+  // since transpose puts token dimension splits in the last dimension of the tensor
+  const size_t scale_stride = output->scale_inv.shape[1];
+  // const size_t scale_stride_transpose =
+  //     return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
+
+  nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr);
+
+  const float *noop_ptr = reinterpret_cast(noop->data.dptr);
+
+  // The RNG state is optional; when supplied it must be an int64 tensor of shape [2].
+  const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
+  const size_t *rng_state = nullptr;
+  if (rng_state_tensor != nullptr) {
+    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
+    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
+               "RNG state should contain 2 64-bit values.");
+    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2},
+               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
+    rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr);
+  }
+
+  // The input element type is fixed to bf16 in this launcher.
+  using IType = bf16;
+
+  alignas(64) CUtensorMap tensor_map_input{};
+  alignas(64) CUtensorMap tensor_map_output{};
+  // alignas(64) CUtensorMap tensor_map_output_transpose{};
+
+  // The last argument is passed as sizeof(IType) * 8 for the input and 4 for the output,
+  // presumably the element width in bits (4 for FP4) -- TODO confirm.
+  create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
+                       sizeof(IType) * 8);
+
+  create_2D_tensor_map(tensor_map_output, output_contiguous, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
+                       cols, 0, 4);
+  // if (return_transpose) {
+  //   create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
+  //                        BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
+  // }
+  // Dynamic shared-memory budget. It sums the input buffers, the row-wise output buffers, a
+  // second output-sized buffer and a scale buffer for the transposed path, plus one extra
+  // TMA_SHMEM_ALIGNMENT (presumably headroom for aligning the base address -- TODO confirm).
+  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
+  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
+  constexpr size_t buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+  constexpr size_t buff_size_aligned_out =
+      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
+  constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
+
+  constexpr size_t in_mem = buff_size_aligned_in;
+
+  constexpr size_t out_data_mem = buff_size_aligned_out;
+  constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
+  constexpr size_t out_scales_transpose_mem = buff_size_scales;
+
+  constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
+
+  constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
+
+  // The nested switch macros are given the two runtime flags together with the constant names
+  // USE_STOCHASTIC_ROUNDING / RETURN_TRANSPOSE that the enclosed kernel selection uses.
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
+
+      TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
+        auto kernel =
+            group_quantize_transpose_nvfp4_kernel;
+
+        // use_2d_quantization is not declared in this function body; presumably a template
+        // parameter of this launcher -- TODO confirm.
+        if constexpr (use_2d_quantization) {
+          NVTE_ERROR("2D quantization is not supported for group quantize transpose.");
+        }
+
+        // Raise the kernel's dynamic shared-memory limit to dshmem_size before launching.
+        NVTE_CHECK_CUDA(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+        kernel<<>>(tensor_map_input, tensor_map_output,
+                   scales_ptr, noop_ptr, rows, cols,
+                   scale_stride, rng_state, kernel_args);
+        NVTE_CHECK_CUDA(cudaGetLastError());
+      }););
+#else
+  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
+#endif  // FP4_TYPE_SUPPORTED
+}
+
+}  // namespace nvfp4
+}  // namespace dispatch
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 38b437b994..0e264eaae3 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -394,6 +394,7 @@ struct QuantizationConfig {
   NVTETensor
rng_state = nullptr; bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; + bool use_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(bool), // force_pow_2_scales @@ -402,7 +403,8 @@ struct QuantizationConfig { sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format sizeof(NVTETensor), // rng_seed and offset sizeof(bool), // nvfp4_2d_quantization - sizeof(bool) // stochastic_rounding + sizeof(bool), // stochastic_rounding + sizeof(bool) // use_fast_math }; }; diff --git a/transformer_engine/common/hadamard_transform/customized_pipeline.cuh b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh new file mode 100644 index 0000000000..b6f6799a49 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh @@ -0,0 +1,222 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_ +#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_ + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; +namespace detail { +// Producer-consumer pipeline implementation +// for UMMA producer. In this case, UMMA barrier arrives are used +// by producer_commit. Use case, accumulator generation as +// the result of MMA instructions. 
+template ,
+          class AtomThrShape_MNK_ = Shape<_1, _1, _1> >
+class CustomizedPipelineTmaUmmaAsync {
+  // Thin wrapper around a PipelineTmaAsync instance (impl_). The producer calls and
+  // consumer_try_wait / consumer_wait / consumer_release forward to impl_ unchanged. The
+  // addition is umma_consumer_release, which signals a stage's EMPTY barrier via a UMMA arrive.
+ public:
+  static constexpr uint32_t Stages = Stages_;
+  using AtomThrShape_MNK = AtomThrShape_MNK_;
+
+ private:
+  using Impl = PipelineTmaAsync;
+
+ public:
+  // Re-export the wrapped pipeline's types so this class exposes the same type names.
+  using FullBarrier = typename Impl::FullBarrier;
+  using EmptyBarrier = typename Impl::EmptyBarrier;
+  using ProducerBarrierType = typename Impl::ProducerBarrierType;
+  using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
+  using PipelineState = typename Impl::PipelineState;
+  using SharedStorage = typename Impl::SharedStorage;
+  using ThreadCategory = typename Impl::ThreadCategory;
+  using Params = typename Impl::Params;
+
+  using McastDirection = McastDirection;
+
+  // Helper function to initialize barriers
+  // Only the warp named by params.initializing_warp initializes the FULL/EMPTY barrier arrays;
+  // every calling thread then executes fence_barrier_init().
+  static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params,
+                                           ClusterShape cluster_shape) {
+    int warp_idx = canonical_warp_idx_sync();
+    if (warp_idx == params.initializing_warp) {
+      // Barrier FULL and EMPTY init
+      constexpr int producer_arv_cnt = 1;
+      auto atom_thr_shape = AtomThrShape_MNK{};
+
+      uint32_t multicast_consumer_arrival_count = params.num_consumers;  // If cluster_size is 1
+      // For a cluster larger than one block the count becomes
+      // (cluster_M / atom_M + cluster_N / atom_N - 1) * num_consumers.
+      if (cute::size(cluster_shape) > 1) {
+        multicast_consumer_arrival_count =
+            ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
+             (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) *
+            params.num_consumers;
+      }
+      CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 &&
+                     "Multicast consumer arrival count must be non-zero");
+      CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero");
+      cutlass::arch::detail::initialize_barrier_array_pair_aligned<
+          decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
+          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt,
+          multicast_consumer_arrival_count);
+    }
+    cutlass::arch::fence_barrier_init();
+  }
+
+  // Computes block_id_mask_ for this block; only threads whose role is Consumer update it.
+  CUTLASS_DEVICE
+  void init_masks(ClusterShape cluster_shape,
+                  dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
+    // Calculate consumer mask
+    if (params_.role == ThreadCategory::Consumer) {
+      block_id_mask_ = detail::calculate_multicast_mask(
+          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
+    }
+  }
+
+  // Direction-specific overload; unlike the overload above it does not check the thread role.
+  // NOTE(review): the two branches below read identically in this text; presumably they differ
+  // in a template argument of calculate_multicast_mask -- confirm against upstream CUTLASS.
+  CUTLASS_DEVICE
+  void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
+    // Calculate consumer mask
+    dim3 block_id_in_cluster = cute::block_id_in_cluster();
+    if (mcast_direction == McastDirection::kRow) {
+      block_id_mask_ = detail::calculate_multicast_mask(
+          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
+    } else {
+      block_id_mask_ = detail::calculate_multicast_mask(
+          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
+    }
+  }
+
+  // Constructor by default initializes barriers and calculates masks.
+  // These operations can be explicitly deferred by specifying InitBarriers and InitMasks.
+  // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called.
+  //
+  // impl_ is always constructed with cute::false_type in the slot before InitMasks (presumably
+  // its own InitBarriers flag -- TODO confirm); barrier setup is done by this class's
+  // init_barriers instead.
+  template
+  CUTLASS_DEVICE CustomizedPipelineTmaUmmaAsync(SharedStorage& storage, Params params,
+                                                ClusterShape cluster_shape, InitBarriers = {},
+                                                InitMasks = {})
+      : impl_(storage, params, cluster_shape, cute::false_type{}, InitMasks{}),
+        params_(params),
+        empty_barrier_ptr_(&storage.empty_barrier_[0]),
+        full_barrier_ptr_(&storage.full_barrier_[0]) {
+    static_assert(cute::is_same_v ||
+                  cute::is_same_v);
+    if constexpr (cute::is_same_v) {
+      init_barriers(storage, params_, cluster_shape);
+    }
+
+    static_assert(cute::is_same_v ||
+                  cute::is_same_v);
+    if constexpr (cute::is_same_v) {
+      init_masks(cluster_shape);
+    }
+  }
+
+  ////////////////////
+  // Producer APIs
+  ////////////////////
+  // Four member functions are always used in pairs:
+  //
+  // * producer_try_acquire and producer_acquire, and
+  // * consumer_try_wait and consumer_wait.
+  //
+  // The two functions with "try" in their names are called "try" functions,
+  // and the other two are conceptually "finalize" functions.
+  // The "try" function in each pair starts the process of waiting on the barrier to flip.
+  // It opportunistically waits for an implementation-dependent timeout.
+  // Whether or not the barrier has flipped yet, the try function will return a token.
+  // If the token indicates that the barrier has not flipped,
+  // then the token must be passed into the corresponding "finalize" function.
+  // The finalize function will then block until the barrier has flipped.
+  // If the token indicates that the barrier _has_ flipped,
+  // then it is still correct to pass it into the finalize function.
+  // The finalize function will return immediately in that case.
+  //
+  // Every producer method below is a direct forward to impl_.
+  CUTLASS_DEVICE
+  ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) {
+    return impl_.producer_try_acquire(state, skip_wait);
+  }
+
+  CUTLASS_DEVICE
+  void producer_acquire(PipelineState state,
+                        ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
+    impl_.producer_acquire(state, barrier_token);
+  }
+
+  CUTLASS_DEVICE
+  void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
+    impl_.producer_expect_transaction(state, transaction_bytes);
+  }
+
+  // NOP for TMA based mainloop
+  CUTLASS_DEVICE
+  void producer_commit(PipelineState state, uint32_t bytes) { impl_.producer_commit(state, bytes); }
+
+  // Prevents early exit of producer blocks in Cluster.
+  // This should be called once before kernel exits.
+  CUTLASS_DEVICE
+  void producer_tail(PipelineState state) { impl_.producer_tail(state); }
+
+  CUTLASS_DEVICE
+  ProducerBarrierType* producer_get_barrier(PipelineState state) {
+    return impl_.producer_get_barrier(state);
+  }
+
+  ////////////////////
+  // Consumer APIs
+  ////////////////////
+  CUTLASS_DEVICE
+  ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) {
+    return impl_.consumer_try_wait(state, skip_wait);
+  }
+
+  CUTLASS_DEVICE
+  void consumer_wait(PipelineState state,
+                     ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
+    impl_.consumer_wait(state, barrier_token);
+  }
+
+  // Releases the stage selected by state.index() through the private UMMA-arrive overload,
+  // never skipping.
+  CUTLASS_DEVICE
+  void umma_consumer_release(PipelineState state) { umma_consumer_release(state.index(), false); }
+  // Plain release, forwarded to impl_.
+  CUTLASS_DEVICE
+  void consumer_release(PipelineState state) { impl_.consumer_release(state); }
+
+ private:
+  Impl impl_;
+  Params params_;
+  EmptyBarrier* empty_barrier_ptr_;  // &storage.empty_barrier_[0]
+  FullBarrier* full_barrier_ptr_;    // &storage.full_barrier_[0]
+  uint16_t block_id_mask_ = 0;       // set by init_masks; passed to the multicast arrives
+  static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
+
+  // Consumer signalling Producer of completion
+  // Ensures all blocks in the Same Row and Column get notified.
+  // `skip` is used as a boolean: when non-zero no arrive is issued.
+  CUTLASS_DEVICE
+  void umma_consumer_release(uint32_t stage, uint32_t skip) {
+    detail::pipeline_check_is_consumer(params_.role);
+    uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]);
+    // {$nv-release-never begin}
+    // TODO: Needs to be updated once Blackwell specialized pipeline is implemented.
+    // XMMA style bar_peek will be tested. We will need to revisit skip interface and
+    // what skip means when we have bar_peek functionality.
+    // A separate MR will implement MMA_2x1SM specialized pipeline.
+    // {$nv-release-never end}
+    if constexpr (is_2sm_mma) {  // Mma cluster shape is 2x1
+      if (!skip) {
+        cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
+      }
+    } else {
+      if (!skip) {
+        // A statically known single-block cluster uses the plain arrive; otherwise multicast.
+        if constexpr (cute::is_static_v && size(ClusterShape{}) == 1) {
+          cutlass::arch::umma_arrive(smem_ptr);
+        } else {
+          cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
+        }
+      }
+    }
+  }
+};
+}  // namespace detail
+}  // namespace cutlass
+
+#endif  // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
index 84eb6bb5c3..ea5e22bbfb 100644
--- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
+++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
@@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector& o
   }

   // Multi zero out multiple amaxes if needed
-  // Curretly don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
+  // Currently don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
   // let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block
   dim3 block_setup_amax(kMaxTensorsPerKernel);
   dim3 grid_setup_amax(1);
diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu
new file mode 100644
index 0000000000..6e071ec79f
--- /dev/null
+++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu
@@ -0,0 +1,1003 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "common/common.h"
+#include "common/util/cuda_runtime.h"
+#include "common/util/curanddx.hpp"
+#include "common/util/ptx.cuh"
+#include "common/utils.cuh"
+#include "cutlass/arch/barrier.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/collective/builders/sm100_common.inl"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/pipeline/pipeline.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/print_error.hpp"
+
+namespace transformer_engine {
+namespace detail {
+namespace {
+
+using namespace cute;
+using cute::
+    Tensor;  // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
+
+using Stride2D = cute::Stride>;
+
+constexpr int kMaxTensorsPerKernel = 64;  // Args must be <4 KB, expand 64 if needed
+// Argument bundle for the grouped kernel, which takes it by value as `kernel_args`.
+// Every per-tensor array has fixed capacity kMaxTensorsPerKernel; `num_tensors` is the bound
+// the device helpers in this file use when indexing them.
+struct MultiAmaxHadamardCastFusionArgs {
+  // (output) Amax buffer for pre-RHT amax buffer
+  void *global_amax_list[kMaxTensorsPerKernel];
+  // output C pointers for each tensor
+  void *output_colwise_list[kMaxTensorsPerKernel];
+  // output scale inverse pointers for each tensor
+  void *output_colwise_scale_inv_list[kMaxTensorsPerKernel];
+  // split sections of each tensor of input
+  int split_sections[kMaxTensorsPerKernel];
+  // Prefix sum (with leading zero) of split_sections of each tensor of input
+  // (one more entry than the other arrays: entry t is where split t starts)
+  int split_sections_range[kMaxTensorsPerKernel + 1];
+  // stride 2D struct for CUTE
+  Stride2D output_stride2d_list[kMaxTensorsPerKernel];
+  // Number of tensors (splits) being processed by kernel
+  int num_tensors;
+};
+
+// Looks up the global amax pointer recorded for `tensor_id`.
+// Returns nullptr when tensor_id is outside [0, num_tensors); otherwise returns
+// global_amax_list[tensor_id] reinterpreted as the function's return type.
+__device__ __forceinline__ float *GetGlobalAmaxPtrByTensorId(
+    MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, int tensor_id) {
+  // directly returns the global amax pointer by tensor id
+  if (tensor_id < 0 || tensor_id >= kernel_args_ptr->num_tensors) {
+    return nullptr;
+  }
+  return reinterpret_cast(kernel_args_ptr->global_amax_list[tensor_id]);
+}
+
+// Maps a position `offset` along the split dimension to the index of the split containing it.
+// The scan returns the first t with offset < split_sections_range[t + 1]. An offset at or
+// beyond split_sections_range[num_tensors] is clamped to the last tensor (num_tensors - 1).
+// The search is a linear scan, so the cost grows with the returned index.
+// NOTE(review): assumes num_tensors >= 1 and offset >= split_sections_range[0]; with
+// num_tensors == 0 the early return yields -1 -- confirm callers never hit that case.
+__device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr,
+                                           int offset) {
+  // Check the kernel args and get the corresponding id
+  const int num_tensors = kernel_args_ptr->num_tensors;
+  if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) {
+    return num_tensors - 1;
+  }
+  // The early return above guarantees this scan stops before tensor_id reaches num_tensors.
+  int tensor_id = 0;
+  while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) {
+    ++tensor_id;
+  }
+  return tensor_id;
+}
+
+// calculate the global encode scale factor for a given global amax.
+// The scale is (kFP8E4M3Max * kFP4E2M1Max) / global_amax, capped by the numeric_limits max
+// passed to the NaN-propagating minimum.
+__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
+  constexpr float kFP8E4M3Max = 448.0f;
+  constexpr float kFP4E2M1Max = 6.0f;
+  // If scale is infinity, return max value of float32
+  float global_encode_scale = cutlass::minimum_with_nan_propagation{}(
+      kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max());
+  // If global amax is 0 or infinity, return 1
+  // (an infinite amax makes the quotient above 0, which the second test catches)
+  return (global_amax == 0.f || global_encode_scale == 0.f) ?
1.f : global_encode_scale;
+}
+
+// Shared-memory layout for the grouped RHT GEMM kernel, which reinterprets its dynamic
+// shared-memory buffer as this struct. It holds the storage of the accumulator and mainloop
+// pipelines, one 64-bit TMA barrier, a slot named tmem_base_ptr (presumably the TMEM
+// allocation's base address -- TODO confirm), and the smem_A / smem_B operand buffers.
+template
+struct SharedStorage {
+  static constexpr int AccumulatorPipelineStageCount = 16;
+  using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
+
+  using AccumulatorPipeline =
+      cutlass::PipelineUmmaAsync;
+  using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
+
+  // The mainloop stage count is taken from mode 3 of the A shared-memory layout.
+  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
+  using MainloopPipeline =
+      cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>;
+  using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
+
+  alignas(16) AccumulatorPipelineStorage accumulator;
+  alignas(16) MainloopPipelineStorage mainloop;
+  alignas(16) cute::uint64_t tma_barrier[1];
+  uint32_t tmem_base_ptr;
+
+  struct TensorStorage : cute::aligned_struct<128, _1> {
+    // cute::array_aligned> smem_A;
+    cute::array_aligned> smem_A;
+    cute::array_aligned> smem_B;
+  } tensors;
+};
+
+// Converts eight float inputs to packed e2m1 (FP4) values with stochastic rounding.
+// Two `cvt.rs.satfinite.e2m1x4.f32` instructions are issued: the first consumes input[0..3]
+// with rbits[0] and writes output_ptr[0], the second consumes input[4..7] with rbits[1] and
+// writes output_ptr[1]. Each instruction lists its four sources in reverse order (%5..%2 and
+// %9..%6). When ARCH_HAS_STOCHASTIC_ROUNDING is false the function raises a device error.
+CUTLASS_DEVICE
+cutlass::Array StochasticNumericConverterBase(
+    cutlass::Array const &input, cutlass::Array const &rbits) {
+  using result_type = cutlass::Array;
+  result_type output;
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
+    // "=h" constrains each output operand to a 16-bit register.
+    auto output_ptr = reinterpret_cast(&output);
+    asm volatile(
+        "{\n"
+        "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
+        "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
+        "}"
+        : "=h"(output_ptr[0]), "=h"(output_ptr[1])
+        : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]),
+          "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific.
"
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+  return output;
+}
+
+// Wider variant built on StochasticNumericConverterBase: `input`, `output` and `rbits` are each
+// viewed as two consecutive Base-sized arrays, and each half is converted with its own half of
+// the random bits. NOTE(review): the element counts are not visible in this text; presumably
+// 16 inputs and 4 random words (twice the Base sizes) -- confirm.
+CUTLASS_DEVICE
+cutlass::Array StochasticNumericConverter(
+    cutlass::Array const &input, cutlass::Array const *rbits) {
+  using result_type = cutlass::Array;
+  result_type output;
+  cutlass::Array *result_ptr =
+      reinterpret_cast *>(&output);
+  cutlass::Array const *source_ptr =
+      reinterpret_cast const *>(&input);
+  cutlass::Array const *rbits_ptr =
+      reinterpret_cast const *>(rbits);
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < 2; i++) {
+    result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
+  }
+  return output;
+}
+
+template
+__global__ static void group_rht_gemm_device(
+    MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, TA const *A, AStride dA,
+    ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B, BStride dB,
+    BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, CSmemLayout, TiledMMA mma,
+    MultiAmaxHadamardCastFusionArgs kernel_args, const size_t *rng_state) {
+  using namespace cute;
+  using X = Underscore;
+  // static constexpr bool kApplyStochasticRounding = true;
+  using ElementAccumulator = float;
+  static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
+  using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
+  static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
+      size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v);
+
+  static constexpr int kTmaRhtTensorTransactionBytes =
+      cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v);
+  static constexpr int AccumulatorPipelineStageCount = 16;
+
+  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
+  using MainloopPipeline =
+      cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>;
+  using MainloopPipelineState = typename MainloopPipeline::PipelineState;
+
+  using TmemAllocator = cute::TMEM::Allocator1Sm;
+  static constexpr int VectorSize = 16;
+  const size_t
rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16)); + + using TensorC = decltype(make_tensor(subbyte_iterator(recast_ptr(nullptr)), // engine + make_shape(int{}, int{}), // (M, N_i) + Stride2D{} // stride (dM, dN) + )); + + using TensorSFC = decltype(make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(make_shape(int{}, // M + make_shape(make_shape(Int<16>{}, _4{}), // (16, 4) + int{}) // n_tiles = split / 64 + ), + make_stride(int{}, // dM = (split / 16) + make_stride(make_stride(_0{}, _1{}), // inner (16,4) layout + _4{}) // tiles stride + )))); + + auto cluster_shape = Shape<_1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + auto mainloop_tiler = Shape<_128, _16, _64>{}; + auto epilogue_tiler = Shape<_128, _64, _64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + // Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + using TensorGC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + make_coord(_, _, _), Step<_1, _1, X>{})); + + using TensorGSFC = 
decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + make_coord(_, _, _), Step<_1, _1, X>{})); + + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma( + SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler)); + + auto bulk_tmem_mma = + TiledMMA::make_fragment_C(append(acc_shape_mma, Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C( + append(acc_shape_epilogue, Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + 32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = 
cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + // if (is_epilogue_warp && elect_one_sync()) { + // // prefetch to make the global amax in cache + // for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + // cute::prefetch(raw_pointer_cast(kernel_args.global_amax_list[i])); + // } + // } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + 
AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, tile_idx_m, _); + int k_tile = 0; + auto barrier_token = + mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename 
MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = + mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = + bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators); + } + + 
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + + // get global amax pointer + int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); + float *global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); + + TC *cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + TSFC *cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + int cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + TensorC 
cur_mC = + cute::make_tensor(cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); + + auto cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); + + auto cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); + + TensorSFC cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); + + TensorGC cur_gC_mn = + local_tile(cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + TensorGSFC cur_gSFC_mn = local_tile( + cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + ); + + Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); + + float global_amax_val = *global_amax_ptr; + float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + + float global_decode_scale = 1.0f / global_encode_scale; + + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + // get the starting index of current k-tile in global tensor, to query the correct global amax + int cur_k_tile_global_elem_idx = (tile_idx_n + k_tile) * 64; + int new_tensor_id = GetTensorId(&kernel_args, cur_k_tile_global_elem_idx); + // float* new_global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, cur_k_tile_global_elem_idx); + global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, new_tensor_id); + // update the scaling factors when it's no longer the same amax pointer + // TODO(zhongbo): the math operations are very expensive + // since the kernel is persistent, we can have a 
cache for all the possible scaling factors + if (tensor_id != new_tensor_id) { + global_amax_val = *global_amax_ptr; + global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + global_decode_scale = 1.0f / global_encode_scale; + tensor_id = new_tensor_id; + // went through the cute operations to update the local tensors + cur_output_colwise_ptr = + reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + cur_mC = cute::make_tensor( + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); + + cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); + + cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); + + cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); + + cur_gC_mn = local_tile( + cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + cur_gSFC_mn = local_tile(cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} + // (BLK_M, BLK_N-like) + ); + + tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); + } + // maybe udpated to the new tensor id + int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; + int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64; + + Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, local_tile_idx_n); + Tensor tCgSFC_mn = cur_gSFC_mn(_, _, tile_idx_m, local_tile_idx_n); + + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_, _, _, 
accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor( + tCgSFC_mn.data(), make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}))); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with unfused + // kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + } + + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, 
global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = + cutlass::divides>{}(1.0, qpvscale_scaled); + } + + // Initialize RNG for tile + const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + // copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while 
(tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void group_rht_gemm_ntt_w_sfc(int m, int n, TA const *A, TB const *B, + MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + const size_t *rng_state, uint32_t sm_count, cudaStream_t stream, + int k_tile_size = 2048) { + using namespace cute; + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + for (size_t i = 0; i < kernel_args_ptr->num_tensors; ++i) { + kernel_args_ptr->output_stride2d_list[i] = + make_stride(kernel_args_ptr->split_sections[i], Int<1>{}); + } + + auto cga_shape = Shape<_1, _1, _1>{}; + auto cga_tile_shape = Shape<_128, _16, _16>{}; + auto cluster_tile_mainloop = Shape<_128, _16, _64>{}; + + // Construct the MMA + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. 
+ CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div( + shape<0>(cga_tile_shape), + shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cga_tile_shape), + shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, + append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = 
UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M, N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16, 16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, "Inner dimension must be divisible by ", + static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, "Outer dimension must be divisible by ", + 4 * static_cast(size<1>(cga_tile_shape)), " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? 
tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto *kernel_ptr = &group_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), TA, decltype(dA), + decltype(sA), decltype(tma_load_a), TB, decltype(dB), decltype(sB), decltype(tma_load_b), TC, + Stride2D, decltype(sC), TSFC, decltype(mma), kEnableStochasticRounding, kUseFastMath>; + + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + (*kernel_ptr)<<>>(M, N, k_tile_size, cga_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, sC, + mma, *kernel_args_ptr, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// this function is used to wrap the group_rht_gemm_ntt_w_sfc function +// to transpose the input tensor A +template +void group_rht_gemm_ttt_wrapper(int m, int n, TA const *A, TB const *B, + MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + const size_t *rng_state, uint32_t sm_count, cudaStream_t stream, + int k_tile_size = 1024) { + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // ultilize as many SMs as possible while keeping + // a relatively large contiguous dimension. 
+ // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + group_rht_gemm_ntt_w_sfc( + n, m, A, B, kernel_args_ptr, rng_state, sm_count, stream, k_tile_size); +} + +} // namespace +} // namespace detail + +void group_hadamard_transform_cast_fusion_columnwise( + const Tensor &input_, std::vector &output_list, const size_t *split_sections, + size_t num_tensors, const Tensor &hadamard_matrix_, QuantizationConfig &quant_config, + cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_cast_fusion_columnwise); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs; + + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + NVTE_CHECK(output_list.size() == num_tensors, + "Number of output tensors should match number of tensors."); + + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + // construct the multi-tensor args + MultiAmaxHadamardCastFusionArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(split_sections[i] % 64 == 0, "component ", i, + " of split_sections should be 64 multiple"); + if (split_sections[i] == 0) { + continue; + } + kernel_args.global_amax_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->amax.dptr); + // TODO(zhongbo): should we change API assumption to use columnwise_data instead of data? 
+ kernel_args.output_colwise_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->data.dptr); + kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->scale_inv.dptr); + kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + kernel_args.num_tensors++; + } + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; 
i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + detail::group_rht_gemm_ttt_wrapper( + /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size););); +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion_columnwise( + const NVTETensor input, NVTETensor *outputs, const NVTETensor hadamard_matrix, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + 
NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Call the multi-tensor Hadamard transform amax implementation. + group_hadamard_transform_cast_fusion_columnwise( + *input_tensor, output_list, split_sections, num_tensors, + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..3932b328ae --- /dev/null +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1499 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "customized_pipeline.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/platform/platform.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; + +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + +constexpr int kMaxTensorsPerKernel = 64; + +struct MultiAmaxHadamardCastFusionArgs { + // (output) Amax buffer for input A amax buffer + void *global_a_amax_list[kMaxTensorsPerKernel]; + // (output) Amax buffer for pre-RHT amax buffer + void *global_d_amax_list[kMaxTensorsPerKernel]; + // output D pointers for each tensor + void *output_colwise_list[kMaxTensorsPerKernel]; + // output SFD inverse pointers for each tensor + void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + // split sections of each tensor of input + int split_sections[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + + // Number of tensors (splits) being processed by kernel + int 
num_tensors; +}; + +// Map an offset along the split dimension to the index of the split (tensor) that contains it. +__device__ __forceinline__ int GetGroupIdx(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + int offset) { + // Offsets at or beyond the total extent (last prefix-sum entry) are clamped to the last split + const int num_tensors = kernel_args_ptr->num_tensors; + if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) { + return num_tensors - 1; + } + // Linear scan over the prefix sums; the <= comparison also steps over empty splits + int group_idx = 0; + while (kernel_args_ptr->split_sections_range[group_idx + 1] <= offset) { + ++group_idx; + } + return group_idx; +} + +// Stochastically rounds eight floats to e2m1 (FP4) with two cvt.rs.satfinite.e2m1x4.f32 instructions; each consumes four floats (listed in reverse order) and one 32-bit word of rbits and writes one 16-bit register. +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + // output_ptr[0] and output_ptr[1] receive the two 16-bit results of the asm block below + auto output_ptr = reinterpret_cast(&output); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + } else { + // Fallback when ARCH_HAS_STOCHASTIC_ROUNDING is false: report a device error instead of converting + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +// Views input, rbits and the result as two halves each and converts every half with StochasticNumericConverterBase. +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverter( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +// Layout of the kernel's shared memory: operand tile buffers, pipeline barrier storage, scheduler slots and per-tensor amax staging. +template +struct SharedStorage { + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineStorage = typename SchedPipeline::SharedStorage; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage; + + // Operand tile buffers for A and B, aligned through cute::aligned_struct<128, _1> + struct TensorStorage : cute::aligned_struct<128, _1> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + // Barrier used for the one-time TMA load of the B (Hadamard) matrix + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) SchedPipelineStorage sched; + alignas(16) SchedThrottlePipelineStorage sched_throttle; + alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_]; +
alignas(16) float global_a_amax[kMaxTensorsPerKernel]; + alignas(16) float global_d_amax[kMaxTensorsPerKernel]; + uint32_t atomic_tile_counter[SchedulerPipelineStageCount_]; + uint32_t tmem_base_ptr; +}; + +// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support +template +__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( + MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile, + TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, MultiAmaxHadamardCastFusionArgs args, + uint32_t *tile_scheduler_workspace, TiledMMA mma, const size_t *rng_state) { + using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR( + "group_row_col_rht_gemm_device is only supported on Blackwell " + "with architecture-specific compilation. 
" + "Try recompiling with sm_100a or similar."); + return; + } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "group_row_col_rht_gemm_device must generate row-wise " + "and/or column-wise output."); +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + + using X = Underscore; + // Accumulator data type for main computation + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kUseFastMath = kUseFastMath_; + + // Constant for RHT tensor processing (tile size etc) + static int constexpr RhtTensorSize = 16; + + // Transaction bytes for TMA transfer on RHT tensor blocks + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + // Mainloop pipeline types for TMA async execution 
and epilogue cluster scheduling + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + + // Compile-time safety: static shapes required for shared memory layouts + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + + struct TileScheduler { + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; + int k_tile_max = 0; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + 
atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + // Split linear_idx by tiles_in_m: the remainder r becomes the M tile index and the quotient q, scaled by k_tile_max, becomes the first N tile of the work unit + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + + // A work unit is valid while (tile_m, tile_n_base) is element-wise less than (tiles_in_m, tiles_in_n) + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); + } + + // True while linear_idx still equals the index this scheduler started with + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } + + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } + + // Fetch a new tile_id using atomics: when pred == 1, atomically add 1 to the global counter at atomic_tile_index_ and return the value the atomic read; otherwise return 0 without issuing the atomic. + CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); + + return tile_id_counter; + } + + // Consumer side: wait on the scheduler pipeline stage, read the next linear tile index from its shared-memory slot, then release the stage + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); + return; + } + + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase +
sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } + + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; + } + + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + update_tile_idx(); + return; + } + }; + + // Allocate and alias shared memory to the kernel's shared storage type + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + + // Compute the number of tiles in M and N after tiling and assign scheduler + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t( + size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); + + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Shapes for accumulated tiles in mainloop and epilogue + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + // Number of threads assigned for various epilogue roles depending on quantization settings + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 
128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using 
AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, + cluster_shape, AccumulatorPipelineInitBarriers{}, + cute::true_type{}); // Delay mask calculation + typename SchedPipeline::Params sched_pipeline_params; + if (is_sched_warp) { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; + } + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; + if 
(is_dma_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; + } + sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; + + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer + if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). + cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + // Partition tensors for tiling according to the mainloop and cluster tilers. 
+ Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + // Shared memory tensors for pipeline + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + // Determine warp/tile positioning + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Partition global to local fragments for A and B + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } + + do { + // is_first_wave indicates whether this scheduler wave is the first among a group. 
+ bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } + } + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + // scheduler.advance(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. 
+ Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int 
i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + ++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. 
+ cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // Use 256-bit fragments for aligned bulk stores + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + // Wait for TMEM allocation for this pipeline to finish + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = + __ldg(reinterpret_cast(args.global_d_amax_list[g])); + } + + size_t rng_seed = 0; + size_t rng_offset = 0; + // Setup RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + + // Determine quantization scale factor layouts/output splits for this group + TSFDLayout sfd_layout; + int cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // Build output tensors for columns and their quant scales + Tensor mD = make_tensor( + cute::subbyte_iterator(reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = + local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = 
c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + + do { + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator( + reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + } + int group_start_offset = args.split_sections_range[group_idx]; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + 
recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + // Compute amax and quantization scales for this tile + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // Copy from TMEM to registers + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); + } + + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + 
auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); + } + + // Prepare stochastic rounding random state if enabled + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } + + // Write quantized FP4 tile and dequant scale to gmem + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before 
Hadamard transform) to FP4 for row-wise usage. + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_a_amax for all groups/tensors + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = + __ldg(reinterpret_cast(args.global_a_amax_list[g])); + } + // RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // Input/output tensors/partitions for row quant warp + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + // Set up layouts for partitioning – tile-by-warp, with vector granularity + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = 
make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + // Allocate temporary register tensors for copying quantization => output + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + + int row_quant_barrier_id = 10; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Update group quantization parameters/scaling + global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + } + + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + auto compute_frgs = reinterpret_cast 
*>(tQArA.data()); + auto output_frgs = + reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + cutlass::NumericArrayConverter{}( + compute_frgs[v]); + amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); 
+ uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} // NOLINT(readability/fn_size) + +template +void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A, + TB const *B, TQA *QA, TSFA *SFA, + MultiAmaxHadamardCastFusionArgs &args, + const size_t *rng_state, uint32_t sm_count, + cudaStream_t stream, int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), + make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), + make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape( + SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape( + SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{})); + + 
using SFALayout = cute::conditional_t; + using SFDLayout = cute::conditional_t; + SFALayout sfa_layout; + SFDLayout sfd_layout; + + if constexpr (kEnableSwizzleSFOutput) { + sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}); + } else { + sfa_layout = make_layout( + make_shape(make_shape(Int{}, hidden_size / SFVecSize), packed_sequence_length), + make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize)); + sfd_layout = make_layout( + make_shape(hidden_size, make_shape(Int{}, packed_sequence_length / SFVecSize)), + make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{}))); + } + + // Define shapes (dynamic) + auto M = hidden_size; + auto N = packed_sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = LayoutRight{}; // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape<_1, _1, _1>; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128, Int, Int>{}; + auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + 
+ // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div( + shape<0>(cluster_tile_shape), + shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cluster_tile_shape), + shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); + + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 4; + static int constexpr 
MainloopPipelineBytes = sizeof( + typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount; + static int constexpr SchedulerThrottlePipelineBytes = + sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr SchedulerPipelineBytes = + sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof( + typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + + SchedulerPipelineBytes + TmemBasePtrsBytes + + TmemDeallocBytes + BTensorBytes + + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), + Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // XXX Dummy + + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma); + + // Assert checks on tile sizes -- no predication + assert(M % size<0>(cluster_tile_shape) == 0); + assert(N % size<1>(cluster_tile_shape) == 
0); + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(sm_count, 1, 1); + + int smem_size = sizeof( + SharedStorage); + + auto *kernel_ptr = &group_row_col_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape), + decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, + decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, + decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), + AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, + kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; + + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Allocate workspace and set to zero + void *tile_scheduler_workspace = nullptr; + NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream)); + + // Launch kernel + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, + tile_scheduler_workspace, mma, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); + + NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream)); +} + +} // namespace +} // namespace detail + +void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector &output_list, + const size_t *split_sections, size_t num_tensors, + const Tensor &hadamard_matrix_, + QuantizationConfig &quant_config, cudaStream_t stream) { 
+ NVTE_API_CALL(group_hadamard_transform_cast_fusion); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs; + + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + NVTE_CHECK(output_list.size() == num_tensors, + "Number of output tensors should match number of tensors."); + + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + // construct the multi-tensor args + MultiAmaxHadamardCastFusionArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + bool all_has_row_quant = true; + bool all_has_col_quant = true; + void *rowwise_data_base_ptr = nullptr; + void *rowwise_scale_inv_base_ptr = nullptr; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(split_sections[i] % 128 == 0, "component ", i, + " of split_sections should be 128 multiple"); + if (split_sections[i] == 0) { + continue; + } + bool has_row_quant = output_list[i]->data.dptr != nullptr; + bool has_col_quant = output_list[i]->columnwise_data.dptr != nullptr; + all_has_row_quant = all_has_row_quant && has_row_quant; + all_has_col_quant = all_has_col_quant && has_col_quant; + // sanity check, the two bool flags cannot be both false + NVTE_CHECK(has_row_quant || has_col_quant, + "At least one of the output tensors must have row or column quant."); + void *amax_rowwise_ptr = + has_row_quant ? reinterpret_cast(output_list[i]->amax.dptr) : nullptr; + void *amax_colwise_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_amax.dptr) : nullptr; + void *rowwise_data_ptr = + has_row_quant ? reinterpret_cast(output_list[i]->data.dptr) : nullptr; + void *rowwise_scale_inv_ptr = + has_row_quant ? 
reinterpret_cast(output_list[i]->scale_inv.dptr) : nullptr; + if (all_has_row_quant && + (rowwise_data_base_ptr == nullptr || rowwise_scale_inv_base_ptr == nullptr)) { + rowwise_data_base_ptr = rowwise_data_ptr; + rowwise_scale_inv_base_ptr = rowwise_scale_inv_ptr; + } + void *output_colwise_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_data.dptr) : nullptr; + void *output_colwise_scale_inv_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr) + : nullptr; + kernel_args.global_a_amax_list[kernel_args.num_tensors] = amax_rowwise_ptr; + kernel_args.global_d_amax_list[kernel_args.num_tensors] = amax_colwise_ptr; + kernel_args.output_colwise_list[kernel_args.num_tensors] = output_colwise_ptr; + kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = + output_colwise_scale_inv_ptr; + kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + kernel_args.num_tensors++; + } + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (use_stochastic_rounding) { + NVTE_CHECK(quant_config.rng_state != nullptr, + "Enabled stochastic rounding without providing RNG state"); + const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + 
checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + const bool use_swizzle_sf_output = false; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_col_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_row_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + 
/*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor *outputs, + const NVTETensor hadamard_matrix, + const size_t *split_sections, + const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Call the multi-tensor Hadamard transform amax implementation. 
+ group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors, + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 12f02dba6b..11325041ae 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -29,7 +29,6 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/command_line.h" -#include "cutlass/util/helper_cuda.hpp" #include "cutlass/util/print_error.hpp" // clang-format off @@ -129,7 +128,8 @@ template + bool kEnableStochasticRounding = false, + bool kUseFastMath = false> __global__ static void rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, @@ -426,7 +426,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); const float global_decode_scale = 1.0f / global_encode_scale; - auto sfd_converter = cutlass::NumericConverter{}; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + static constexpr float fp4_max_inv = 1.0f / fp4_max; + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } do { for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { @@ -469,10 +475,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ++accumulator_pipe_consumer_state; - // Cast data from FP32 to BF16 to FP32. 
- auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with unfused + // kernels + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + } auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); @@ -481,14 +490,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + } auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); - auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + } // Initialize RNG for tile const size_t rng_sequence @@ -532,7 
+554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -644,16 +666,15 @@ rht_gemm_ntt_w_sfc(int m, int n, TC, decltype(dC), decltype(sC), TSFC, decltype(mma), - kEnableStochasticRounding>; + kEnableStochasticRounding, + kUseFastMath>; - bool status = cudaFuncSetAttribute(*kernel_ptr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) + ); - if (status != cudaSuccess) { - std::cerr << "Error: Failed to set Shared Memory size." << std::endl; - return; - } (*kernel_ptr) <<< dimGrid, dimBlock, smem_size, stream >>> (M, N, k_tile_size, cga_tile_shape, @@ -663,11 +684,12 @@ rht_gemm_ntt_w_sfc(int m, int n, SFC, mma, global_amax, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); } // this function is used to wrap the rht_gemm_ntt_w_sfc function //to transpose the input tensor A -template +template void rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -690,7 +712,7 @@ rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - rht_gemm_ntt_w_sfc( + rht_gemm_ntt_w_sfc( n, m, A, B, C, SFC, global_amax, @@ -800,20 +822,23 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out } else if (m < 1024 || n < 1024) { k_tile_size = 512; } + TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kUseStochasticRounding, - detail::rht_gemm_ttt_wrapper( - /*m=*/m, - /*n=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*C=*/reinterpret_cast(output_t.dptr), - /*SFC=*/reinterpret_cast(scale_inv_t.dptr), - /*global_amax=*/reinterpret_cast(global_amax.dptr), - /*rng_state=*/rng_state, - /*sm_count=*/sm_count, - /*stream=*/stream, - 
/*k_tile_size=*/k_tile_size);); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size););); } } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index a3235e84f1..19fbe431aa 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, cudaStream_t stream); +/*! \brief Casts grouped input tensor to quantized output tensors. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. + * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 05541fe30c..112cb9b54d 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -86,6 +86,43 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp int random_sign_mask, int random_sign_mask_t, cudaStream_t stream); +/*! + * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation. + * + * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] outputs Array of output tensors. + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor. + * \param[in] num_tensors Number of output tensors, must be > 0. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion_columnwise( + const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix, + const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + +/*! + * \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation. + * + * This function is experimental and the API is not stable. 
Group_ prefix means contiguous input concatenated + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] outputs Array of output tensors. + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor. + * \param[in] num_tensors Number of output tensors, must be > 0. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs, + const NVTETensor hadamard_matrix, + const size_t* split_sections, size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b2e04ba69f..19cb646be2 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigNVFP42DQuantization = 5, /*! Whether to enable stochastic rounding */ kNVTEQuantizationConfigStochasticRounding = 6, + /*! Whether to enable fast math operations with reduced accuracy. + * + * Optimizations are kernel-specific and they may be applied + * inconsistently between kernels. + */ + kNVTEQuantizationConfigUseFastMath = 7, kNVTEQuantizationConfigNumAttributes }; @@ -997,6 +1003,12 @@ class QuantizationConfigWrapper { &stochastic_rounding, sizeof(bool)); } + /*! \brief Set whether to enable fast math operations */ + void set_use_fast_math(bool use_fast_math) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath, + &use_fast_math, sizeof(bool)); + } + private: /*! 
\brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 8d9563b789..4a140b4376 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, // Write attribute size NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); - NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; - *size_written = attr_size; + if (size_written != nullptr) { + *size_written = attr_size; + } // Return immediately if buffer is not provided if (buf == nullptr) { @@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(buf, &config_.rng_state, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size); + break; + case kNVTEQuantizationConfigStochasticRounding: + std::memcpy(buf, &config_.stochastic_rounding, attr_size); + break; + case kNVTEQuantizationConfigUseFastMath: + std::memcpy(buf, &config_.use_fast_math, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigStochasticRounding: std::memcpy(&config_.stochastic_rounding, buf, attr_size); break; + case kNVTEQuantizationConfigUseFastMath: + std::memcpy(&config_.use_fast_math, buf, 
attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ac541435c7..aa9d800c7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -16,6 +16,7 @@ #include "../extensions.h" #include "common.h" +#include "common/util/system.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" @@ -709,179 +710,225 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } -void split_quantize_nvfp4_impl(const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, - const std::vector &quantizers) { - // Check tensor lists - const size_t num_tensors = split_sections.size(); - NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ", - input_list.size(), "."); - NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, - " output tensors, but got ", output_list.size(), "."); - NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors, - " NVFP4 quantizers, but got ", quantizers.size(), "."); +// Owns all allocations/wrappers backing quant_config_list[*].set_rng_state(...). +struct StochasticRngStateResources { + at::Tensor rng_states_tensor; // [2 * num_tensors], int64, CUDA + at::Tensor rng_states_tensor_colwise; // optional, same shape/dtype/device + std::vector te_rng_state_list; + std::vector te_rng_state_list_colwise; + + bool enabled{false}; + bool need_separate_rng_states{false}; + bool with_bulk_generate_rng_states{false}; +}; + +// Populates quant_config_list (+ optional colwise list) with rng_state pointers and stochastic flag. 
+static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( + size_t num_tensors, bool stochastic_rounding, bool with_bulk_generate_rng_states, + bool need_separate_rng_states, + std::vector &quant_config_list_rowwise, + std::vector &quant_config_list_colwise) { + // the return object will be used to keep rng states alive + StochasticRngStateResources res; + res.enabled = stochastic_rounding; + res.need_separate_rng_states = need_separate_rng_states; + res.with_bulk_generate_rng_states = with_bulk_generate_rng_states; + + if (!stochastic_rounding) return res; + + // Basic sanity: caller usually pre-sizes these to num_tensors. + TORCH_CHECK(quant_config_list_rowwise.size() == num_tensors, + "quant_config_list_rowwise must be sized to num_tensors"); + if (need_separate_rng_states) { + TORCH_CHECK(quant_config_list_colwise.size() == num_tensors, + "quant_config_list_colwise must be sized to num_tensors when " + "need_separate_rng_states=true"); + } - // Trivial cases - if (num_tensors == 0) { - return; + const size_t rng_elts_per_thread = + res.with_bulk_generate_rng_states ? 
(1024 * num_tensors) : 1024; + + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + res.rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); + if (need_separate_rng_states) { + res.rng_states_tensor_colwise = torch::empty({static_cast(2 * num_tensors)}, opts); } - if (input.numel() == 0) { - for (const auto &tensor : input_list) { - NVTE_CHECK(tensor.numel() == 0, - "Input tensor has zero elements but got split with non-zero elements"); + + res.te_rng_state_list.reserve(num_tensors); + if (need_separate_rng_states) res.te_rng_state_list_colwise.reserve(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // Rowwise RNG state + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; + philox_unpack(philox_args, rng_state_ptr); + + res.te_rng_state_list.push_back(makeTransformerEngineTensor( + static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); + quant_config_list_rowwise[i].set_stochastic_rounding(true); + + // Colwise RNG state (only if you truly need a different sequence) + if (need_separate_rng_states) { + // re-initialize philox_args for colwise RNG state + at::PhiloxCudaState philox_args_col = init_philox_state(gen, rng_elts_per_thread); + int64_t *rng_state_ptr_colwise = + static_cast(res.rng_states_tensor_colwise.data_ptr()) + i * 2; + + philox_unpack(philox_args_col, rng_state_ptr_colwise); + + res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( + static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); + quant_config_list_colwise[i].set_stochastic_rounding(true); } - return; - } - // Assume all 
quantizers have identical config - const auto &quantizer = *quantizers.front(); - NVTE_CHECK(!quantizer.with_2d_quantization, - "NVFP4 split-quantize does not support 2D quantization"); - NVTE_CHECK(!quantizer.with_amax_reduction, - "NVFP4 split-quantize does not support amax reduction"); + // break the loop if we are using bulk generate rng states + if (res.with_bulk_generate_rng_states) break; + } - // Check input tensor shape - const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + return res; +} - // CUDA stream - auto stream = at::cuda::getCurrentCUDAStream(); +// Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) +void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + cudaStream_t stream) { + const size_t num_tensors = split_sections.size(); + const auto &quantizer = *quantizers.front(); - // Objects for TE C API std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; - std::vector quant_config_list; for (size_t i = 0; i < num_tensors; ++i) { nvte_tensor_input_list.push_back(input_list[i].data()); nvte_tensor_output_list.push_back(output_list[i].data()); + } + + // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance + bool all_aligned_token_dim = + std::all_of(split_sections.begin(), split_sections.end(), + [](size_t split_section) { return split_section % 128 == 0; }); + + // in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice + // so that rowwise and colwise will have different random numbers + bool need_separate_rng_states = + (!all_aligned_token_dim) && quantizer.rowwise_usage && quantizer.columnwise_usage; + + // Objects for TE C API + 
std::vector quant_config_list; + std::vector quant_config_list_colwise; + for (size_t i = 0; i < num_tensors; ++i) { quant_config_list.emplace_back(QuantizationConfigWrapper()); + quant_config_list_colwise.emplace_back(QuantizationConfigWrapper()); } + // this is true because we have already built grouped kernels for rowwise and colwise quantization with RHT + bool with_bulk_generate_rng_states = true; + // Stochastic rounding - // When both rowwise and columnwise quantization are used, - // we need separate RNG states for each to ensure they use different random numbers. - std::vector te_rng_state_list; - std::vector te_rng_state_columnwise_list; - std::vector quant_config_columnwise_list; - at::Tensor rng_states_tensor; - at::Tensor rng_states_columnwise_tensor; - const bool need_separate_columnwise_rng = - quantizer.stochastic_rounding && quantizer.with_rht && quantizer.columnwise_usage; - - if (quantizer.stochastic_rounding) { - // TODO(zhongbo): remove the for loop of generating rng states with a single call - // with rng_elts_per_thread = 1024 * num_tensors - // Change to the bulk generate rng states api when grouped quantize is available - const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened - auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - - // Allocate columnwise RNG resources when separate RNG is needed - if (need_separate_columnwise_rng) { - rng_states_columnwise_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_columnwise_list.emplace_back(QuantizationConfigWrapper()); - } + bool need_stochastic_rounding = quantizer.stochastic_rounding; + auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( + num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, + need_separate_rng_states, quant_config_list, 
quant_config_list_colwise); + + // Enable NVFP4 kernels to use math operations that sacrifice + // accuracy for performance. These optimizations are experimental + // and inconsistently implemented. + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + for (auto &config : quant_config_list) { + config.set_use_fast_math(true); } - for (size_t i = 0; i < num_tensors; ++i) { - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - // Generate RNG state for rowwise quantization - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_ptr = static_cast(rng_states_tensor.data_ptr()) + i * 2; - philox_unpack(philox_args, rng_state_ptr); - te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); - quant_config_list[i].set_rng_state(te_rng_state_list[i].data()); - quant_config_list[i].set_stochastic_rounding(true); - - // Generate separate RNG state for columnwise quantization - if (need_separate_columnwise_rng) { - at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_columnwise_ptr = - static_cast(rng_states_columnwise_tensor.data_ptr()) + i * 2; - philox_unpack(philox_args_columnwise, rng_state_columnwise_ptr); - te_rng_state_columnwise_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_columnwise_ptr), std::vector{2}, DType::kInt64)); - quant_config_columnwise_list[i].set_rng_state(te_rng_state_columnwise_list[i].data()); - quant_config_columnwise_list[i].set_stochastic_rounding(true); - } + for (auto &config : quant_config_list_colwise) { + config.set_use_fast_math(true); } } - // Perform multi-tensor quantization - if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data - // Check that config is supported - NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only 
supported for bfloat16 input"); - - // Compute amaxes - if (quantizer.with_post_rht_amax) { - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for RHT(input.t) - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_hadamard_transform_amax( - input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream); - }); - } else { - // RHT is enabled, but amax is pre-RHT amax - NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax"); - } + auto &quant_config_list_colwise_to_use = + need_separate_rng_states ? quant_config_list_colwise : quant_config_list; - // Check that RHT matrix is available - NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, - "RHT matrix is not available."); - auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); + // Compute amaxes + if (quantizer.with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. 
Columnwise amax = amax for RHT(input.t) + nvte_group_hadamard_transform_amax( + input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream); + } else { + // RHT is enabled, but amax is pre-RHT amax + NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax"); + } - // Quantize tensors individually - NVTE_SCOPED_GIL_RELEASE({ - for (size_t i = 0; i < num_tensors; i++) { - if (input_list[i].numel() == 0) { - continue; // Skip tensors with no elements - } + // Check that RHT matrix is available + NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, + "RHT matrix is not available."); + auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); - // Direct NVFP4 quantization for row-wise data - if (quantizer.rowwise_usage) { - auto out_rowwise_data = output_list[i].get_rowwise_data(); - auto out_rowwise_scale_inv = output_list[i].get_rowwise_scale_inv(); - auto out_rowwise_amax = output_list[i].get_amax(); - TensorWrapper out_rowwise(output_list[i].scaling_mode()); - out_rowwise.set_rowwise_data(out_rowwise_data.data_ptr, - static_cast(out_rowwise_data.dtype), - out_rowwise_data.shape); - out_rowwise.set_rowwise_scale_inv(out_rowwise_scale_inv.data_ptr, - static_cast(out_rowwise_scale_inv.dtype), - out_rowwise_scale_inv.shape); - out_rowwise.set_amax(out_rowwise_amax.data_ptr, - static_cast(out_rowwise_amax.dtype), out_rowwise_amax.shape); - nvte_quantize_v2(input_list[i].data(), out_rowwise.data(), quant_config_list[i], stream); + if (all_aligned_token_dim) { + // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose + nvte_group_hadamard_transform_cast_fusion( + input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); + } else { + // Separate quantization for rowwise usage and 
columnwise usage + // Rowwise quantization fusion with grouped version + if (quantizer.rowwise_usage) { + std::vector out_identity_list; + std::vector nvte_tensor_out_identity_list; + for (size_t i = 0; i < num_tensors; i++) { + bool is_empty_split = input_list[i].numel() == 0; + TensorWrapper out_identity(output_list[i].scaling_mode()); + auto out_identity_data = output_list[i].get_rowwise_data(); + auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv(); + auto out_identity_amax = output_list[i].get_amax(); + if (!is_empty_split) { + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, + static_cast(out_identity_amax.dtype), + out_identity_amax.shape); } + out_identity_list.emplace_back(std::move(out_identity)); + nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); + } + nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(), + split_sections.data(), num_tensors, quant_config_list[0], + stream); + } - // RHT + NVFP4 quantize for column-wise data - if (quantizer.columnwise_usage) { - // Get the output column-wise data, scale_inv, and amax - auto out_columnwise_data = output_list[i].get_columnwise_data(); - auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); - auto out_columnwise_amax = output_list[i].get_columnwise_amax(); - - // Flatten column-wise data to 2D + // Columnwise RHT quantization fusion with grouped version + if (quantizer.columnwise_usage) { + std::vector out_transpose_list; + std::vector nvte_tensor_out_transpose_list; + for (size_t i = 0; i < num_tensors; i++) { + bool is_empty_split = input_list[i].numel() == 0; + auto out_columnwise_data = output_list[i].get_columnwise_data(); + auto 
out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); + auto out_columnwise_amax = output_list[i].get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout. + TensorWrapper out_transpose(output_list[i].scaling_mode()); + if (!is_empty_split) { auto colwise_data_shape = out_columnwise_data.shape; std::vector colwise_data_shape_2d; colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; + for (size_t j = 1; j < colwise_data_shape.ndim; ++j) { + last_dim *= colwise_data_shape.data[j]; } colwise_data_shape_2d.push_back(last_dim); - // Create a wrapper for the columnwise output, as the rowwise output. - // The reason is due to the input `rht_output_t` is already in the transposed layout. - // Thus, we only need a rowwise quantization to generate the columnwise output. - TensorWrapper out_transpose(output_list[i].scaling_mode()); out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), colwise_data_shape_2d); @@ -891,53 +938,151 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, out_transpose.set_amax(out_columnwise_amax.data_ptr, static_cast(out_columnwise_amax.dtype), out_columnwise_amax.shape); - - // RHT + NVFP4 quantize kernel - // Use separate RNG state for columnwise to ensure different random numbers than rowwise - auto &columnwise_quant_config = - need_separate_columnwise_rng ? quant_config_columnwise_list[i] : quant_config_list[i]; - nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(), - rht_matrix_nvte.data(), - columnwise_quant_config, stream); } + out_transpose_list.emplace_back(std::move(out_transpose)); + nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); } - }); - - } else { // NVFP4 quantize - // We need: - // 1. 
Rowwise amax = amax for input - // 2. Columnwise amax = amax for input too - // Columnwise amax will be filled with a fused D2D copy from rowwise amax - // Note that the multi compute amax API expects rowwise amax pointer to be not null - // So we need to set the pointer accordingly to make colwise-only quantization work - std::vector orig_amax_ptr_list; - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + nvte_group_hadamard_transform_cast_fusion_columnwise( + input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, + quant_config_list_colwise_to_use[0], stream); } - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - }); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } +} + +void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + cudaStream_t stream) { + const size_t num_tensors = input_list.size(); + const auto &quantizer = *quantizers.front(); + + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + + // In this case without RHT, the rowwise and colwise quantization are fused 
+ // we don't need separate rng states for rowwise and colwise + bool need_separate_rng_states = false; + + // Objects for TE C API + std::vector quant_config_list; + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + } + + // TODO: this is only true because the non-RHT path doesn't have grouped kernels yet, which we can be optimized + // so that we can generate all rng states at once + bool with_bulk_generate_rng_states = false; + + bool need_stochastic_rounding = quantizer.stochastic_rounding; + + // place holder for colwise rng states, which are not needed in this case + std::vector dummy_quant_config_list_colwise; + + auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( + num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, + need_separate_rng_states, quant_config_list, + dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for input too + // Columnwise amax will be filled with a fused D2D copy from rowwise amax + // Note that the multi compute amax API expects rowwise amax pointer to be not null + // So we need to set the pointer accordingly to make colwise-only quantization work + std::vector orig_amax_ptr_list; + for (size_t i = 0; i < num_tensors; i++) { + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? 
rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + } + nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, stream); + for (size_t i = 0; i < num_tensors; i++) { + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } + + // Quantize tensors individually + for (size_t i = 0; i < num_tensors; i++) { + // skip this round if input is empty + if (input_list[i].numel() == 0) { + continue; } + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } +} - // Quantize tensors individually - NVTE_SCOPED_GIL_RELEASE({ - for (size_t i = 0; i < num_tensors; i++) { - // skip this round if input is empty - if (input_list[i].numel() == 0) { - continue; - } - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); - } - }); +void split_quantize_nvfp4_impl(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers) { + // Check tensor lists + const size_t num_tensors = split_sections.size(); + NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ", + input_list.size(), "."); + NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, + " output tensors, but got ", output_list.size(), "."); + NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors, + " NVFP4 quantizers, but got ", quantizers.size(), "."); + + // sanity check all the quantizers have the same scaling mode + bool all_same_scaling_mode = + std::all_of(quantizers.begin(), quantizers.end(), [&](const NVFP4Quantizer *quantizer) { + return quantizer->get_scaling_mode() == quantizers.front()->get_scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "All 
quantizers must have the same scaling mode"); + + // Trivial cases + if (num_tensors == 0) { + return; + } + if (input.numel() == 0) { + for (const auto &tensor : input_list) { + NVTE_CHECK(tensor.numel() == 0, + "Input tensor has zero elements but got split with non-zero elements"); + } + return; } + + // Assume all quantizers have identical config + const auto &quantizer = *quantizers.front(); + NVTE_CHECK(!quantizer.with_2d_quantization, + "NVFP4 split-quantize does not support 2D quantization"); + NVTE_CHECK(!quantizer.with_amax_reduction, + "NVFP4 split-quantize does not support amax reduction"); + + // Check input tensor shape + const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; + NVTE_CHECK(input_last_dim % 128 == 0, + "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + + // CUDA stream + auto stream = at::cuda::getCurrentCUDAStream(); + + // Perform multi-tensor quantization + NVTE_SCOPED_GIL_RELEASE({ + if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data + // Check that config is supported + NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); + // Fuse the rowwise and colwise into one when the kernel is ready + split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, + quantizers, stream); + } else { // NVFP4 quantize + // Fuse the rowwise and colwise into one when the kernel is ready + split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, + stream); + } + }); } } // namespace diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c73c09b317..fd748d1b21 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } } - // Restriction for the RHT cast fusion 
kernel. + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index fbe2ee6d1c..3f5995230c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int: if recipe.mxfp8(): return 32 if recipe.nvfp4(): - return 64 + return 128 return 16