-
Notifications
You must be signed in to change notification settings - Fork 30
enable blockwise FP8 quantization on rocm #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,7 +200,8 @@ list(APPEND transformer_engine_cuda_sources | |
| transpose/cast_transpose_fusion.cu | ||
| transpose/transpose_fusion.cu | ||
| transpose/multi_cast_transpose.cu | ||
| transpose/quantize_transpose_vector_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| transpose/quantize_transpose_square_blockwise.cu | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should stay in transformer_engine_cuda_arch_specific_sources |
||
| transpose/swap_first_dims.cu | ||
| dropout/dropout.cu | ||
| fused_attn/flash_attn.cu | ||
|
|
@@ -233,7 +234,6 @@ list(APPEND transformer_engine_cuda_sources | |
| comm_gemm_overlap/userbuffers/userbuffers.cu) | ||
|
|
||
| set(cuda_only_cuda_sources | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| fused_attn/fused_attn_f16_max512_seqlen.cu | ||
| fused_attn/fused_attn_f16_arbitrary_seqlen.cu | ||
| fused_attn/fused_attn_fp8.cu | ||
|
|
@@ -257,7 +257,6 @@ list(APPEND transformer_engine_cuda_arch_specific_sources | |
| multi_tensor/compute_scale.cu | ||
| recipe/mxfp8_scaling.cu | ||
| recipe/nvfp4.cu | ||
| transpose/quantize_transpose_square_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise_fp4.cu) | ||
|
|
||
| set(cuda_only_cuda_arch_specific_sources | ||
|
|
@@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources | |
| hadamard_transform/hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/group_hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu | ||
| hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu | ||
| transpose/quantize_transpose_square_blockwise.cu) | ||
| hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu) | ||
|
|
||
| # Compiling the files with the worst compilation time first to hopefully overlap | ||
| # better with the faster-compiling cpp files | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -636,8 +636,8 @@ struct TypeExtrema<fp8e4m3> { | |
| static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; | ||
| #else | ||
| static float max; | ||
| static float max_inverse; | ||
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this change necessary? fp8e4m3 max depends on the device type on AMD.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The The fix uses
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed |
||
| #endif | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,30 +5,50 @@ | |
| ************************************************************************/ | ||
|
|
||
| #include <cuda.h> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
|
|
||
| #include "common/common.h" | ||
| #include "common/recipe/recipe_common.cuh" | ||
| #include "common/util/cuda_runtime.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "common/util/ptx.cuh" | ||
| #endif | ||
|
asdfvg123 marked this conversation as resolved.
Outdated
|
||
| #include "common/utils.cuh" | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ | ||
| (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) | ||
| #define TMA_HW_SUPPORTED | ||
| #endif | ||
| #endif | ||
|
|
||
| namespace transformer_engine { | ||
| namespace { | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ROCm should not use it. See how *_sync calls are guarded in other places |
||
| #else | ||
| using WarpSyncMask = unsigned; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu; | ||
| #endif | ||
|
|
||
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is platform dependent. |
||
| #else | ||
| constexpr size_t kThreadsPerWarp = 32; | ||
| #endif | ||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t BLOCK_TILE_DIM = 128; | ||
| constexpr size_t WARP_TILE_DIM_X = 32; | ||
|
|
@@ -62,6 +82,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP | |
|
|
||
| #define MIN(a, b) (a < b ? a : b) | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) | ||
| block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, | ||
|
|
@@ -133,7 +154,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -247,6 +268,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| #endif | ||
| } | ||
| } | ||
| #endif // TMA_HW_SUPPORTED | ||
|
|
||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( | ||
|
|
@@ -360,7 +382,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -456,6 +478,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
|
asdfvg123 marked this conversation as resolved.
Outdated
|
||
| template <typename OutputType> | ||
| CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { | ||
| CUtensorMapDataType dataType; | ||
|
|
@@ -473,6 +496,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size | |
| /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); | ||
| return tensor_map_output_trans; | ||
| } | ||
| #endif // TMA_HW_SUPPORTED | ||
|
|
||
| } // namespace | ||
| } // namespace transformer_engine | ||
|
|
@@ -546,6 +570,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| const bool full_tile = | ||
| row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
| if (full_tile) { | ||
| CUtensorMap tensor_map_output_trans; | ||
| if (return_transpose) { | ||
|
|
@@ -561,7 +586,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, | ||
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| tensor_map_output_trans, pow_2_scale, noop_ptr); | ||
| } else { | ||
| } else | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's avoid splitting up the } else { line here. We can add another macro guard instead if needed. |
||
| #else | ||
| (void)full_tile; | ||
| #endif | ||
| { | ||
| block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, | ||
| OutputType> | ||
| <<<grid, THREADS_PER_BLOCK, 0, stream>>>( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,13 +5,17 @@ | |
| ************************************************************************/ | ||
|
|
||
| #include <cuda.h> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
| #include <utility> | ||
|
|
||
| #include "common/common.h" | ||
|
|
@@ -26,6 +30,14 @@ namespace { | |
| using transformer_engine::detail::FP8BlockwiseColumnwiseOption; | ||
| using transformer_engine::detail::FP8BlockwiseRowwiseOption; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; | ||
| #else | ||
| using WarpSyncMask = unsigned; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu; | ||
| #endif | ||
|
|
||
| // clang-format off | ||
| /* | ||
|
|
||
|
|
@@ -145,14 +157,23 @@ Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast | |
| */ | ||
| // clang-format on | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; | ||
| #else | ||
| constexpr size_t kThreadsPerWarp = 32; | ||
| #endif | ||
|
|
||
| // Hyperparameters for performance tuning | ||
| constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization | ||
| constexpr int kNVecIn = 8; // The number of elements each LDG touches | ||
| constexpr int kNVecOut = 16; // The number of elements each STG touches | ||
| constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The kernel expects 8 waves / block , so I increased the number of threads |
||
| #else | ||
| constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total | ||
| #endif | ||
|
|
||
| // Auto-calculated constants, do not modify directly) | ||
| static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); | ||
|
|
@@ -259,7 +280,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo | |
| // the first thread to do the reduction. | ||
| const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; | ||
| // This mask represents which threads should do the reduction together. | ||
| const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; | ||
| const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane; | ||
| const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; | ||
| #pragma unroll | ||
| for (int iter = 0; iter < num_iterations; ++iter) { | ||
|
|
@@ -350,7 +371,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo | |
| // the first thread to do the reduction. | ||
| const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; | ||
| // This mask represents which threads should do the reduction together. | ||
| const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; | ||
| const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane; | ||
| const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; | ||
| #pragma unroll | ||
| for (int iter = 0; iter < num_iterations; ++iter) { | ||
|
|
@@ -474,7 +495,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo | |
| const bool is_src_lane = thr_idx_in_warp == 0; | ||
| amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
| constexpr int lane_zero = 0; | ||
| amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero); | ||
| amax = __shfl_sync(kFullWarpMask, amax, lane_zero); | ||
| // Step 3.4: Compute scale | ||
| CType scale; | ||
| scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't this be always
Trueon ROCm TE?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated to (9, 5)