-
Notifications
You must be signed in to change notification settings - Fork 30
enable blockwise FP8 quantization on rocm #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copyright |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,7 +200,8 @@ list(APPEND transformer_engine_cuda_sources | |
| transpose/cast_transpose_fusion.cu | ||
| transpose/transpose_fusion.cu | ||
| transpose/multi_cast_transpose.cu | ||
| transpose/quantize_transpose_vector_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| transpose/quantize_transpose_square_blockwise.cu | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should stay in transformer_engine_cuda_arch_specific_sources |
||
| transpose/swap_first_dims.cu | ||
| dropout/dropout.cu | ||
| fused_attn/flash_attn.cu | ||
|
|
@@ -233,7 +234,6 @@ list(APPEND transformer_engine_cuda_sources | |
| comm_gemm_overlap/userbuffers/userbuffers.cu) | ||
|
|
||
| set(cuda_only_cuda_sources | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| fused_attn/fused_attn_f16_max512_seqlen.cu | ||
| fused_attn/fused_attn_f16_arbitrary_seqlen.cu | ||
| fused_attn/fused_attn_fp8.cu | ||
|
|
@@ -257,7 +257,6 @@ list(APPEND transformer_engine_cuda_arch_specific_sources | |
| multi_tensor/compute_scale.cu | ||
| recipe/mxfp8_scaling.cu | ||
| recipe/nvfp4.cu | ||
| transpose/quantize_transpose_square_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise_fp4.cu) | ||
|
|
||
| set(cuda_only_cuda_arch_specific_sources | ||
|
|
@@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources | |
| hadamard_transform/hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/group_hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu | ||
| transpose/quantize_transpose_square_blockwise.cu) | ||
| hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu) | ||
|
|
||
| # Compiling the files with the worst compilation time first to hopefully overlap | ||
| # better with the faster-compiling cpp files | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| /************************************************************************* | ||
| * This file was modified for portability to AMDGPU | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
|
|
@@ -18,17 +20,31 @@ | |
| #include "common/util/ptx.cuh" | ||
| #include "common/utils.cuh" | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ | ||
| (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) | ||
| #define TMA_HW_SUPPORTED | ||
| #endif | ||
| #endif | ||
|
|
||
| namespace transformer_engine { | ||
| namespace { | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ROCm should not use it. See how *_sync calls are guarded in other places |
||
| #else | ||
| using WarpSyncMask = unsigned; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu; | ||
| #endif | ||
|
|
||
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is platform dependent. |
||
| #else | ||
| constexpr size_t kThreadsPerWarp = 32; | ||
| #endif | ||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t BLOCK_TILE_DIM = 128; | ||
| constexpr size_t WARP_TILE_DIM_X = 32; | ||
|
|
@@ -40,8 +56,12 @@ constexpr size_t BLOCK_TILE_DIM = 128; | |
| constexpr size_t WARP_TILE_DIM_X = 64; | ||
| constexpr size_t WARP_TILE_DIM_Y = 32; | ||
| constexpr size_t THREAD_TILE_DIM_X = 8; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t THREAD_TILE_DIM_Y = 4; | ||
| #else | ||
| constexpr size_t THREAD_TILE_DIM_Y = 8; | ||
| #endif | ||
| #endif | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t NUM_BYTES_PER_BANK = 4; | ||
|
|
@@ -62,6 +82,16 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP | |
|
|
||
| #define MIN(a, b) (a < b ? a : b) | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| __device__ __forceinline__ float blockwise_warp_reduce_max(float val) { | ||
| #pragma unroll | ||
| for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) | ||
| val = fmaxf(val, __shfl_down(val, delta, kThreadsPerWarp)); | ||
| return val; | ||
| } | ||
| #endif | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) | ||
| block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, | ||
|
|
@@ -130,10 +160,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole this code is under #ifndef HIP_PLATFORM_AMD |
||
| warp_tile_amax = blockwise_warp_reduce_max(amax); | ||
| #else | ||
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| #endif | ||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -247,6 +281,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| #endif | ||
| } | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( | ||
|
|
@@ -357,10 +392,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| warp_tile_amax = blockwise_warp_reduce_max(amax); | ||
| #else | ||
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
| #endif | ||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -456,6 +495,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <typename OutputType> | ||
| CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { | ||
| CUtensorMapDataType dataType; | ||
|
|
@@ -473,6 +513,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size | |
| /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); | ||
| return tensor_map_output_trans; | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| } // namespace | ||
| } // namespace transformer_engine | ||
|
|
@@ -543,9 +584,10 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| return_transpose, kReturnTranspose, | ||
|
|
||
| dim3 grid(num_blocks_x, num_blocks_y, 1); | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| const bool full_tile = | ||
| row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; | ||
|
|
||
| if (full_tile) { | ||
| CUtensorMap tensor_map_output_trans; | ||
| if (return_transpose) { | ||
|
|
@@ -573,6 +615,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| } // full-tile | ||
| #else | ||
| block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, | ||
| OutputType> | ||
| <<<grid, THREADS_PER_BLOCK, 0, stream>>>( | ||
| reinterpret_cast<const InputType*>(input.dptr), | ||
| reinterpret_cast<OutputType*>(output.dptr), | ||
| reinterpret_cast<OutputType*>(output_t.dptr), | ||
| reinterpret_cast<float*>(scale_inv.dptr), | ||
| reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, | ||
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
| ) // return_transpose | ||
| ) // OutputType | ||
| ) // InputType | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.