Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/custom_quickreduce.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/all.h> // 1222

#ifdef USE_ROCM

Expand Down
27 changes: 22 additions & 5 deletions csrc/quickreduce/quick_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <vector>
#include <hip/hip_runtime.h>
#include "quick_reduce_impl.cuh"
// #define caltime

#define HIP_CHECK(err) \
do { \
Expand All @@ -22,13 +23,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x;
int grid = gridDim.x;

while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
flag_color, data_size_per_phase);
block += grid;
flag_color++;
}
Expand All @@ -41,21 +43,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
}

enum QuickReduceQuantLevel {
Expand Down Expand Up @@ -173,6 +175,12 @@ struct DeviceComms {
uint32_t num_blocks = divceil(msg_size, kTileSize);
uint32_t grid = min(kMaxNumBlocks, num_blocks);
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
#ifdef caltime
hipEvent_t start, end;
hipEventCreate(&start);
hipEventCreate(&end);
hipEventRecord(start, stream);
#endif
switch (quant_level_) {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
Expand All @@ -187,6 +195,15 @@ struct DeviceComms {
TWOSHOT_DISPATCH(CodecFP)
break;
}
#ifdef caltime
hipEventRecord(end, stream);
hipEventSynchronize(end);
float elapsed_time;
hipEventElapsedTime(&elapsed_time, start, end);
if (rank == 0) {
printf("msg_size:%u, quant_level:%d, qr_latency:%f\n", msg_size, quant_level, elapsed_time * 1000);
}
#endif
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
flag_color += divceil(N, grid);
Expand Down
107 changes: 32 additions & 75 deletions csrc/quickreduce/quick_reduce_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include <hip/hip_runtime.h>
#include "base.h"
#include <hip/hip_fp16.h>
#include <hip/hip_fp4.h>


namespace quickreduce {

Expand Down Expand Up @@ -391,14 +394,14 @@ struct CodecQ8 : public CodecBase {
static constexpr int kWorldSize = world_size;

// Codec tile size process by this workgroup.
// Each threads processes a fragment of f16x8_t (16B),
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
// Each threads processes a fragment of fp16x8_t (16B),
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 2176;
static constexpr int kRankTileScaleOffset = 2048;
static constexpr int kRankTileStride = 1152;
static constexpr int kRankTileScaleOffset = 1024;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0,
"kRankTileSize must be 16B aligned.");
"kRankTransmittedTileSize must be 16B aligned.");

static constexpr int kRankBufferTileStride =
kRankTileStride / sizeof(int32x4_t);
Expand All @@ -409,31 +412,22 @@ struct CodecQ8 : public CodecBase {

// Constants configuration

// {-1/128.0h, -1/128.0h}, f16x2_t
static constexpr int kScaleFactor =
std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
// {1/6.0h, 1/6.0h}, f16x2_t
static int constexpr kScaleFactor = 0x31553155;

// {1e-7, 1e-7}, f16x2_t
static constexpr int kScaleEpsilon =
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;

// {-128, -128}, f16x2_t
static constexpr int kRangeMin =
std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
// {+127, +127}, f16x2_t
static constexpr int kRangeMax =
std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;

// {+128, +128}, int16x2_t
static constexpr int kRangeBias = 0x00800080;

__quickreduce_device_inline__ CodecQ8(int thread, int rank)
: CodecBase(thread, rank) {}

__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
int32x4_t const* __restrict__ data) {
const int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];

// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
Expand All @@ -449,32 +443,21 @@ struct CodecQ8 : public CodecBase {
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}

// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));

for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}

// Pack 8 x q8 into int32x2_t
int32x2_t qw;
qw[0] = q[0] | (q[1] << 8);
qw[1] = q[2] | (q[3] << 8);
float con_scale = 1.0f; // 无缩放
int32_t qw;
__amd_fp16x2_storage_t* y = reinterpret_cast<__amd_fp16x2_storage_t*>(&w);
qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[0], con_scale, 0);
qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[1], con_scale, 1);
qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[2], con_scale, 2);
qw = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(qw, y[3], con_scale, 3);

// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr =
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);

Expand All @@ -490,45 +473,23 @@ struct CodecQ8 : public CodecBase {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);

int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
int32_t qw = __builtin_nontemporal_load(qw_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);

*recv_buffer += kRankBufferTileStride;

// Unpack q8 into fp16x8_t
int32x4_t w;
{
static uint constexpr kMask00FF = 0x00FF00FF;

// {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1024 = 0x64006400;

// {-1152.0, -1152.0}, fp16x2_t
static uint constexpr kHalf2_1152 = 0xE480E480;

#pragma unroll
for (int i = 0; i < 4; i++) {
if constexpr (std::is_same<T, half>::value) {
int32_t q8 =
((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
w[i] = packed_add<half>(q8, kHalf2_1152);
} else {
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
int32x4_t w;{
__amd_fp16x2_storage_t* y = reinterpret_cast<__amd_fp16x2_storage_t*>(&w);
__hip_fp4x2_storage_t* qww = reinterpret_cast<__hip_fp4x2_storage_t*>(&qw);
y[0] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[0], 1.0f, 0);
y[1] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[1], 1.0f, 0);
y[2] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[2], 1.0f, 0);
y[3] = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(qww[3], 1.0f, 0);
}

// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
Expand All @@ -538,7 +499,6 @@ struct CodecQ8 : public CodecBase {
}
}
};

// Twoshot All Reduce
template <typename T, class Codec, bool cast_bf2half>
struct AllReduceTwoshot {
Expand All @@ -553,13 +513,12 @@ struct AllReduceTwoshot {
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
uint32_t flag_color, int64_t data_size_per_phase) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
Expand Down Expand Up @@ -588,12 +547,10 @@ struct AllReduceTwoshot {
// rank responsible for this segment.
uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset =
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;

uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset =
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;

for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
Expand Down
Loading