From 80b7731d4668394ef3fc0d420259876e182e544f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 30 Oct 2025 07:11:31 +0000 Subject: [PATCH] support fp4, but need del caltime Signed-off-by: Haoyang Li --- csrc/custom_quickreduce.cu | 2 +- csrc/quickreduce/quick_reduce.h | 27 +++++-- csrc/quickreduce/quick_reduce_impl.cuh | 107 ++++++++----------------- 3 files changed, 55 insertions(+), 81 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 33d0d4a7226e..e098194d28a7 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include // 1222 #ifdef USE_ROCM diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 4fe4c44be7eb..ed2529f48313 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -3,6 +3,7 @@ #include #include #include "quick_reduce_impl.cuh" +// #define caltime #define HIP_CHECK(err) \ do { \ @@ -22,13 +23,14 @@ template __global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, int rank, uint8_t** dbuffer_list, - uint32_t data_offset, uint32_t flag_color) { + uint32_t data_offset, uint32_t flag_color, + int64_t data_size_per_phase) { int block = blockIdx.x; int grid = gridDim.x; while (block < num_blocks) { AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, - flag_color); + flag_color, data_size_per_phase); block += grid; flag_color++; } @@ -41,21 +43,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } else if (world_size == 4) { \ using LineCodec = __codec; \ using AllReduceKernel = AllReduceTwoshot; \ hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } else if (world_size == 8) { \ using LineCodec = __codec; \ using AllReduceKernel = AllReduceTwoshot; \ hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } enum QuickReduceQuantLevel { @@ -173,6 +175,12 @@ struct DeviceComms { uint32_t num_blocks = divceil(msg_size, kTileSize); uint32_t grid = min(kMaxNumBlocks, num_blocks); auto quant_level_ = static_cast(quant_level); +#ifdef caltime + hipEvent_t start, end; + hipEventCreate(&start); + hipEventCreate(&end); + hipEventRecord(start, stream); +#endif switch (quant_level_) { case QuickReduceQuantLevel::INT8: TWOSHOT_DISPATCH(CodecQ8) @@ -187,6 +195,15 @@ struct DeviceComms { TWOSHOT_DISPATCH(CodecFP) break; } +#ifdef caltime + hipEventRecord(end, stream); + hipEventSynchronize(end); + float elapsed_time; + hipEventElapsedTime(&elapsed_time, start, end); + if (rank == 0) { + printf("msg_size:%u, quant_level:%d, qr_latency:%f\n", msg_size, quant_level, elapsed_time * 1000); + } +#endif HIP_CHECK(cudaGetLastError()); // Rotate the flag color. flag_color += divceil(N, grid); diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 17816c552d25..a35c816bc48c 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -2,6 +2,9 @@ #include #include "base.h" +#include +#include + namespace quickreduce { @@ -391,14 +394,14 @@ struct CodecQ8 : public CodecBase { static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. - // Each threads processes a fragment of f16x8_t (16B), - // into a int8x8_t (8B) and a f16 scale shared among 32 values. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. static constexpr int kRankAtoms = kAtoms / kWorldSize; - static constexpr int kRankTileStride = 2176; - static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; static_assert(kRankTransmittedTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); + "kRankTransmittedTileSize must be 16B aligned."); static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); @@ -409,31 +412,22 @@ struct CodecQ8 : public CodecBase { // Constants configuration - // {-1/128.0h, -1/128.0h}, f16x2_t - static constexpr int kScaleFactor = - std::is_same::value ? 0xA000A000 : 0xBC00BC00; + // {1/6.0h, 1/6.0h}, f16x2_t + static int constexpr kScaleFactor = 0x31553155; // {1e-7, 1e-7}, f16x2_t static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; - // {-128, -128}, f16x2_t - static constexpr int kRangeMin = - std::is_same::value ? 0xD800D800 : 0xC300C300; - // {+127, +127}, f16x2_t - static constexpr int kRangeMax = - std::is_same::value ? 0x57F057F0 : 0x42FE42FE; - - // {+128, +128}, int16x2_t - static constexpr int kRangeBias = 0x00800080; __quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + const int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group // In 2 blocks of values, upper/lower halves of the f16x2_t int wblockmax = group_abs_max(atom); @@ -449,32 +443,21 @@ struct CodecQ8 : public CodecBase { int32x4_t w; for (int i = 0; i < 4; i++) { w[i] = packed_mul(atom[i], encoding_scale); - w[i] = packed_max(w[i], kRangeMin); - w[i] = packed_min(w[i], kRangeMax); } - // Convert from f16x2_t to uint16x2_t - int32x4_t q; - { - int16_t* qi = reinterpret_cast(&q); - T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); - - for (int i = 0; i < 4; i++) { - q[i] = packed_add(q[i], kRangeBias); - } - } - - // Pack 8 x q8 into int32x2_t - int32x2_t qw; - qw[0] = q[0] | (q[1] << 8); - qw[1] = q[2] | (q[3] << 8); + float con_scale = 1.0f; // 无缩放 + int32_t qw; + __amd_fp16x2_storage_t* y = reinterpret_cast<__amd_fp16x2_storage_t*>(&w); + qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[0], con_scale, 0); + qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[1], con_scale, 1); + qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[2], con_scale, 2); + qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[3], con_scale, 3); // Write quantized atom to send_buffer // note: only the group leader stores the scale uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); @@ -490,45 +473,23 @@ struct CodecQ8 : public CodecBase { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); - int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int32_t qw = __builtin_nontemporal_load(qw_ptr); int qs = __builtin_nontemporal_load(qs_ptr); *recv_buffer += kRankBufferTileStride; - // Unpack q8 into fp16x8_t - int32x4_t w; - { - static uint constexpr kMask00FF = 0x00FF00FF; - - // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1024 = 0x64006400; - - // {-1152.0, -1152.0}, fp16x2_t - static uint constexpr kHalf2_1152 = 0xE480E480; - -#pragma unroll - for (int i = 0; i < 4; i++) { - if constexpr (std::is_same::value) { - int32_t q8 = - ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; - w[i] = packed_add(q8, kHalf2_1152); - } else { - int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; - int16_t low = static_cast(int16_2 & 0xFFFF); - int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); - nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); - nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); - nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); - int32_t packed_bf16 = *reinterpret_cast(&bf2); - w[i] = packed_add(packed_bf16, kRangeMin); - } - } + int32x4_t w;{ + __amd_fp16x2_storage_t* y = reinterpret_cast<__amd_fp16x2_storage_t*>(&w); + __hip_fp4x2_storage_t* qww = reinterpret_cast<__hip_fp4x2_storage_t*>(&qw); + y[0] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[0], 1.0f, 0); + y[1] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[1], 1.0f, 0); + y[2] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[2], 1.0f, 0); + y[3] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[3], 1.0f, 0); } - // Apply decoding scales for (int i = 0; i < 4; i++) { w[i] = packed_mul(w[i], qs); @@ -538,7 +499,6 @@ struct CodecQ8 : public CodecBase { } } }; - // Twoshot All Reduce template struct AllReduceTwoshot { @@ -553,13 +513,12 @@ struct AllReduceTwoshot { int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers uint32_t const data_offset, // offset to start of the data buffer - uint32_t flag_color) { + uint32_t flag_color, int64_t data_size_per_phase) { // Topology int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = buffer_list[rank]; Codec codec(thread, rank); int block_id = blockIdx.x; - int grid_size = gridDim.x; // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; @@ -588,12 +547,10 @@ struct AllReduceTwoshot { // rank responsible for this segment. uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize; - uint32_t comm_data1_offset = - grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset; uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); - uint32_t comm_flags1_offset = - grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset; for (int r = 0; r < kWorldSize; r++) { int32x4_t* send_buffer =