From 8335488412ee193625cf5039fabe237cc9b21c94 Mon Sep 17 00:00:00 2001 From: asdfvg123 Date: Mon, 1 Jun 2026 20:25:43 +0000 Subject: [PATCH 1/3] enable blockwise FP8 quantization on rocm --- tests/pytorch/test_float8blockwisetensor.py | 5 ++- transformer_engine/common/CMakeLists.txt | 8 ++--- .../common/cast/dispatch/quantize.cuh | 4 --- transformer_engine/common/common.h | 4 +-- .../quantize_transpose_square_blockwise.cu | 35 +++++++++++++++++-- .../quantize_transpose_vector_blockwise.cu | 27 ++++++++++++-- transformer_engine/pytorch/quantization.py | 2 +- 7 files changed, 66 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 5fc6aa51c..4d24fbb1e 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -45,7 +45,10 @@ def _to_list(x: Union[Iterable, Any]) -> List: DimsType = Union[Iterable[int], int] # TODO replace with call to fp8.py when recipe added. -recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) +if IS_HIP_EXTENSION: + recipe_available = get_device_compute_capability() >= (9, 0) +else: + recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 02eaaea93..bd584ac6c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -200,7 +200,8 @@ list(APPEND transformer_engine_cuda_sources transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu - transpose/quantize_transpose_vector_blockwise.cu #CUDA-only + transpose/quantize_transpose_vector_blockwise.cu + transpose/quantize_transpose_square_blockwise.cu transpose/swap_first_dims.cu dropout/dropout.cu fused_attn/flash_attn.cu @@ -233,7 +234,6 @@ list(APPEND transformer_engine_cuda_sources comm_gemm_overlap/userbuffers/userbuffers.cu) set(cuda_only_cuda_sources - transpose/quantize_transpose_vector_blockwise.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_fp8.cu @@ -257,7 +257,6 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu - transpose/quantize_transpose_square_blockwise.cu #CUDA-only transpose/quantize_transpose_vector_blockwise_fp4.cu) set(cuda_only_cuda_arch_specific_sources @@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu - hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu - transpose/quantize_transpose_square_blockwise.cu) + hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu) # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 579caee06..1b52a7c68 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -164,7 +164,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, #endif break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); @@ -196,7 +195,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } -#endif//#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -317,7 +315,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens #endif break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT), @@ -351,7 +348,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } -#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2273253ec..87af7c3f8 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -636,8 +636,8 @@ struct TypeExtrema { static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; static constexpr float max_inverse = 1.0 / max; #else - static float max; - static float max_inverse; + static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; #endif }; diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 3a8536587..3dc9b250d 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -5,30 +5,50 @@ ************************************************************************/ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/util/cuda_runtime.h" +#ifndef __HIP_PLATFORM_AMD__ #include "common/util/ptx.cuh" +#endif #include "common/utils.cuh" +#ifndef __HIP_PLATFORM_AMD__ #if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) #define TMA_HW_SUPPORTED #endif +#endif namespace transformer_engine { namespace { +#ifdef __HIP_PLATFORM_AMD__ +using WarpSyncMask = uint64_t; +constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; +#else +using WarpSyncMask = unsigned; +constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu; +#endif + // const values configuration +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t kThreadsPerWarp = 64; +#else constexpr size_t kThreadsPerWarp = 32; +#endif #ifdef TMA_HW_SUPPORTED constexpr size_t BLOCK_TILE_DIM = 128; constexpr size_t WARP_TILE_DIM_X = 32; @@ -62,6 +82,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP #define MIN(a, b) (a < b ? a : b) +#ifdef TMA_HW_SUPPORTED template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, @@ -133,7 +154,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) warp_tile_amax = warp_reduce_max(amax); // broadcast the amax to all threads in a warp from the lane 0 constexpr int lane_zero = 0; - warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); // reduce warp_tile_amax across multiple warps in a thread block using shared mem if (tid_in_warp == 0) { @@ -247,6 +268,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #endif } } +#endif // TMA_HW_SUPPORTED template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( @@ -360,7 +382,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose warp_tile_amax = warp_reduce_max(amax); // broadcast the amax to all threads in a warp from the lane 0 constexpr int lane_zero = 0; - warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); // reduce warp_tile_amax across multiple warps in a thread block using shared mem if (tid_in_warp == 0) { @@ -456,6 +478,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } +#ifdef TMA_HW_SUPPORTED template CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { CUtensorMapDataType dataType; @@ -473,6 +496,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); return tensor_map_output_trans; } +#endif // TMA_HW_SUPPORTED } // namespace } // namespace transformer_engine @@ -546,6 +570,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; +#ifdef TMA_HW_SUPPORTED if (full_tile) { CUtensorMap tensor_map_output_trans; if (return_transpose) { @@ -561,7 +586,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale, noop_ptr); - } else { + } else +#else + (void)full_tile; +#endif + { block_scaled_cast_transpose_kernel_notaligned <<>>( diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index df869b433..272220b7a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -5,13 +5,17 @@ ************************************************************************/ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include #include #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include "common/common.h" @@ -26,6 +30,14 @@ namespace { using transformer_engine::detail::FP8BlockwiseColumnwiseOption; using transformer_engine::detail::FP8BlockwiseRowwiseOption; +#ifdef __HIP_PLATFORM_AMD__ +using WarpSyncMask = uint64_t; +constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; +#else +using WarpSyncMask = unsigned; +constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu; +#endif + // clang-format off /* @@ -145,14 +157,23 @@ Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast */ // clang-format on +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t kThreadsPerWarp = 64; +#else constexpr size_t kThreadsPerWarp = 32; +#endif // Hyperparameters for performance tuning constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches + +#ifdef __HIP_PLATFORM_AMD__ +constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total +#else constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total +#endif // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); @@ -259,7 +280,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // the first thread to do the reduction. const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; // This mask represents which threads should do the reduction together. - const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane; const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; #pragma unroll for (int iter = 0; iter < num_iterations; ++iter) { @@ -350,7 +371,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // the first thread to do the reduction. const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; // This mask represents which threads should do the reduction together. - const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane; const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; #pragma unroll for (int iter = 0; iter < num_iterations; ++iter) { @@ -474,7 +495,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const bool is_src_lane = thr_idx_in_warp == 0; amax = warp_reduce_max(amax); constexpr int lane_zero = 0; - amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero); + amax = __shfl_sync(kFullWarpMask, amax, lane_zero); // Step 3.4: Compute scale CType scale; scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index dcd12b7a0..6cae60ccd 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -103,7 +103,7 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if IS_HIP_EXTENSION: - return False, "FP8 block scaled gemm not yet supported for ROCm" + return True, "" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return ( From 622630112d4f612e4166b7b1dff59fbcf08e9b3e Mon Sep 17 00:00:00 2001 From: asdfvg123 Date: Thu, 4 Jun 2026 17:35:00 +0000 Subject: [PATCH 2/3] enable blockwise FP8 C++ tests on ROCm, fix wave64 bugs, remove redundant HIP guards, revert unnecessary common.h change --- ci/pytorch.sh | 1 + tests/cpp/operator/CMakeLists.txt | 3 +- .../cpp/operator/test_cast_float8blockwise.cu | 16 ++++++---- tests/pytorch/test_float8blockwisetensor.py | 2 +- transformer_engine/common/common.h | 4 +-- .../quantize_transpose_square_blockwise.cu | 29 +++++++++++++++---- .../quantize_transpose_vector_blockwise.cu | 6 ++-- 7 files changed, 41 insertions(+), 20 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 32fbf02f8..1cd324e36 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -51,6 +51,7 @@ run_test_config(){ run_default_fa 1 test_deferred_init.py run_default_fa 1 test_quantized_tensor.py run_default_fa 1 test_float8_current_scaling_exact.py + run_default_fa 1 test_float8blockwisetensor.py test $_fus_attn = auto -o $_fus_attn = ck && run 1 test_cpu_offloading.py test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 901e5ec9f..13280028a 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -15,7 +15,7 @@ add_executable(test_operator test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu #CUDA-only test + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu @@ -41,7 +41,6 @@ if(USE_ROCM) get_target_property(test_cuda_sources test_operator SOURCES) # Remove CUDA-only tests and add ROCm specific ones list(REMOVE_ITEM test_cuda_sources - test_cast_float8blockwise.cu test_swizzle.cu test_grouped_gemm.cu) list(APPEND test_cuda_sources diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index b43cc6bd8..1dfed7a73 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -67,7 +67,7 @@ void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale qscale = ldexpf(1.0f, static_cast(exp) - 127); } - float qscale_inv = 1.0 / qscale; + float qscale_inv = 1.0f / qscale; *qscale_out = qscale; *qscale_inv_out = qscale_inv; } @@ -227,17 +227,22 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method } inline size_t scale_align_stride(size_t inner_elements) { +#ifdef __HIP_PLATFORM_AMD__ + return inner_elements; +#else return ((inner_elements + 4u - 1u) / 4u) * 4u; +#endif }; void compare_scaling_factors(const std::string& name, const float* test, const float* ref, const size_t row_blocks, const size_t col_blocks, - const size_t test_stride, const size_t ref_stride) { + const size_t test_stride, const size_t ref_stride, + const float atol = 1e-6f) { for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i * test_stride + j; const int ref_idx = i * ref_stride + j; - ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; @@ -247,13 +252,14 @@ void compare_scaling_factors(const std::string& name, const float* test, const f void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, const float* ref, const size_t rows, - const size_t col_blocks) { + const size_t col_blocks, + const float atol = 1e-6f) { const size_t test_stride = scale_align_stride(rows); for (int i = 0; i < rows; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i + test_stride * j; const int ref_idx = i + rows * j; - ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 4d24fbb1e..8b050178a 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -46,7 +46,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: # TODO replace with call to fp8.py when recipe added. if IS_HIP_EXTENSION: - recipe_available = get_device_compute_capability() >= (9, 0) + recipe_available = get_device_compute_capability() >= (9, 5) else: recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 87af7c3f8..2273253ec 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -636,8 +636,8 @@ struct TypeExtrema { static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; static constexpr float max_inverse = 1.0 / max; #else - static constexpr float max = 448.0f; - static constexpr float max_inverse = 1.0 / max; + static float max; + static float max_inverse; #endif }; diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 3dc9b250d..6f3674569 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -1,27 +1,23 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/util/cuda_runtime.h" -#ifndef __HIP_PLATFORM_AMD__ #include "common/util/ptx.cuh" -#endif #include "common/utils.cuh" #ifndef __HIP_PLATFORM_AMD__ @@ -60,8 +56,12 @@ constexpr size_t BLOCK_TILE_DIM = 128; constexpr size_t WARP_TILE_DIM_X = 64; constexpr size_t WARP_TILE_DIM_Y = 32; constexpr size_t THREAD_TILE_DIM_X = 8; +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else constexpr size_t THREAD_TILE_DIM_Y = 8; #endif +#endif #ifdef TMA_HW_SUPPORTED constexpr size_t NUM_BYTES_PER_BANK = 4; @@ -82,6 +82,15 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP #define MIN(a, b) (a < b ? a : b) +#ifdef __HIP_PLATFORM_AMD__ +__device__ __forceinline__ float blockwise_warp_reduce_max(float val) { +#pragma unroll + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) + val = fmaxf(val, __shfl_down(val, delta, kThreadsPerWarp)); + return val; +} +#endif + #ifdef TMA_HW_SUPPORTED template __global__ void __launch_bounds__(THREADS_PER_BLOCK) @@ -151,7 +160,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } } // Reduce amax in the warp (32x32 tile) +#ifdef __HIP_PLATFORM_AMD__ + warp_tile_amax = blockwise_warp_reduce_max(amax); +#else warp_tile_amax = warp_reduce_max(amax); +#endif // broadcast the amax to all threads in a warp from the lane 0 constexpr int lane_zero = 0; warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); @@ -379,7 +392,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } // Reduce amax in the warp (32x32 tile) +#ifdef __HIP_PLATFORM_AMD__ + warp_tile_amax = blockwise_warp_reduce_max(amax); +#else warp_tile_amax = warp_reduce_max(amax); +#endif // broadcast the amax to all threads in a warp from the lane 0 constexpr int lane_zero = 0; warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero); diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 272220b7a..c1ff6e951 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -1,21 +1,19 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include #include #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include "common/common.h" From bdf905e8d21b3709c6f365bb1e630c3a04a6a4d1 Mon Sep 17 00:00:00 2001 From: asdfvg123 Date: Thu, 4 Jun 2026 23:01:45 +0000 Subject: [PATCH 3/3] fix constexpr chain, arch guard --- ci/pytorch.sh | 2 +- .../cpp/operator/test_cast_float8blockwise.cu | 22 ++++++++++---- tests/pytorch/test_float8blockwisetensor.py | 2 +- transformer_engine/common/common.h | 7 +++-- .../quantize_transpose_square_blockwise.cu | 30 ++++++++++++------- transformer_engine/pytorch/quantization.py | 5 +++- 6 files changed, 47 insertions(+), 21 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 1cd324e36..3bb282bc1 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -49,9 +49,9 @@ run_test_config(){ fi run 1 test_cuda_graphs.py run_default_fa 1 test_deferred_init.py - run_default_fa 1 test_quantized_tensor.py run_default_fa 1 test_float8_current_scaling_exact.py run_default_fa 1 test_float8blockwisetensor.py + run_default_fa 1 test_quantized_tensor.py test $_fus_attn = auto -o $_fus_attn = ck && run 1 test_cpu_offloading.py test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index 1dfed7a73..db2cb7171 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -67,7 +67,7 @@ void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale qscale = ldexpf(1.0f, static_cast(exp) - 127); } - float qscale_inv = 1.0f / qscale; + float qscale_inv = 1.0 / qscale; *qscale_out = qscale; *qscale_inv_out = qscale_inv; } @@ -236,13 +236,19 @@ inline size_t scale_align_stride(size_t inner_elements) { void compare_scaling_factors(const std::string& name, const float* test, const float* ref, const size_t row_blocks, const size_t col_blocks, - const size_t test_stride, const size_t ref_stride, - const float atol = 1e-6f) { + const size_t test_stride, const size_t ref_stride) { +#ifdef __HIP_PLATFORM_AMD__ + const float atol = 1e-6f; +#endif for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i * test_stride + j; const int ref_idx = i * ref_stride + j; +#ifdef __HIP_PLATFORM_AMD__ ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) +#else + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) +#endif << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; @@ -252,14 +258,20 @@ void compare_scaling_factors(const std::string& name, const float* test, const f void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, const float* ref, const size_t rows, - const size_t col_blocks, - const float atol = 1e-6f) { + const size_t col_blocks) { +#ifdef __HIP_PLATFORM_AMD__ + const float atol = 1e-6f; +#endif const size_t test_stride = scale_align_stride(rows); for (int i = 0; i < rows; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i + test_stride * j; const int ref_idx = i + rows * j; +#ifdef __HIP_PLATFORM_AMD__ ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) +#else + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) +#endif << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 8b050178a..4d24fbb1e 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -46,7 +46,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: # TODO replace with call to fp8.py when recipe added. if IS_HIP_EXTENSION: - recipe_available = get_device_compute_capability() >= (9, 5) + recipe_available = get_device_compute_capability() >= (9, 0) else: recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2273253ec..295af2c99 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -635,9 +635,12 @@ struct TypeExtrema { #elif defined(__HIP_DEVICE_COMPILE__) static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; static constexpr float max_inverse = 1.0 / max; +#elif defined(HIP_FP8_TYPE_FNUZ) + static constexpr float max = 240.0f; + static constexpr float max_inverse = 1.0 / max; #else - static float max; - static float max_inverse; + static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; #endif }; diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 6f3674569..0b5202668 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -91,7 +91,7 @@ __device__ __forceinline__ float blockwise_warp_reduce_max(float val) { } #endif -#ifdef TMA_HW_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, @@ -281,7 +281,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #endif } } -#endif // TMA_HW_SUPPORTED +#endif // __HIP_PLATFORM_AMD__ template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( @@ -495,7 +495,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } -#ifdef TMA_HW_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ template CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { CUtensorMapDataType dataType; @@ -513,7 +513,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); return tensor_map_output_trans; } -#endif // TMA_HW_SUPPORTED +#endif // __HIP_PLATFORM_AMD__ } // namespace } // namespace transformer_engine @@ -584,10 +584,10 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor return_transpose, kReturnTranspose, dim3 grid(num_blocks_x, num_blocks_y, 1); + +#ifndef __HIP_PLATFORM_AMD__ const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; - -#ifdef TMA_HW_SUPPORTED if (full_tile) { CUtensorMap tensor_map_output_trans; if (return_transpose) { @@ -603,11 +603,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale, noop_ptr); - } else -#else - (void)full_tile; -#endif - { + } else { block_scaled_cast_transpose_kernel_notaligned <<>>( @@ -619,6 +615,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, pow_2_scale, noop_ptr); } // full-tile +#else + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale, noop_ptr); +#endif // __HIP_PLATFORM_AMD__ ) // return_transpose ) // OutputType ) // InputType diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 6cae60ccd..71cf74f4c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -103,7 +103,10 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if IS_HIP_EXTENSION: - return True, "" + gpu_arch = get_device_compute_capability() + if gpu_arch >= (9, 0): + return True, "" + return False, "Device arch gfx9+ or newer is required for FP8 block scaling execution." if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return (