diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention.cu b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention.cu new file mode 100644 index 000000000..185aa89a5 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention.cu @@ -0,0 +1,1120 @@ +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "cuda_compat.h" + +#include +#include "dtype_fp8.cuh" +#include "quant_utils.cuh" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + +using _B8x8 = uint2; + +////// Non temporal load stores /////// + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +/////////////////////////////////////// + +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _B16x8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + + const int warp_start_token_idx = + partition_start_token_idx + warpid * WARP_SIZE; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + // fetch block number for q and k + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } + + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; + #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } + + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } + } + + float alibi_slope[QHLOOP]; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int qhead_idx = h * 4 + lane4id; + alibi_slope[h] = (qhead_idx < GQA_RATIO) + ? alibi_slopes[wg_start_head_idx + qhead_idx] + : 0.f; + } + } + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } + } + } + } + + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); + if constexpr (KHELOOP > 8) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); + } // KHELOOP>8 + dout[h] *= scale; + } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + floatx4 tmp = {0}; + #pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + } + dout[h] = tmp; + } + + const int lane4_token_idx = 4 * (global_token_idx >> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? __expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } + } // warp within context + + __syncthreads(); + + const int num_heads = gridDim.z * GQA_RATIO; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + float global_qk_max = -FLT_MAX; + float warp_qk_max[NWARPS]; + const int head_idx = 4 * h + lane4id; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max[w] = shared_qk_max[w][head_idx]; + global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); + } + float global_exp_sum = 0.0f; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + global_exp_sum += + shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); + } + if (head_idx < GQA_RATIO) { + max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_qk_max; + exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_exp_sum; + } + const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * + __expf(qk_max[h] - global_qk_max); + dout[h] *= global_inv_sum_scale; + } + // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + _B16x4 logits[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + logits[h] = from_floatx4(dout[h]); + } + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } + } + } else { // warp in context + // iterate across heads + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc = {0}; + // iterate over tokens + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + } + } + } // warp in context + + __syncthreads(); + + if (warpid == 0) { + _B16x4 vout[QHLOOP][VHELOOP]; + // iterate across heads + scalar_t* out_ptr; + int out_num_partitions; + if (context_len > partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); + } + const int head_size_elem = vh * WARP_SIZE + laneid; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } + +void paged_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int64_t block_size, int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const int head_size = query.size(2); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else { + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_dtypes.h b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_dtypes.h new file mode 100644 index 000000000..64f86381d --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_dtypes.h @@ -0,0 +1,7 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float16.cuh" +#include "dtype_float32.cuh" +#include "dtype_bfloat16.cuh" +#include "dtype_fp8.cuh" diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_generic.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_generic.cuh new file mode 100644 index 000000000..62409c0cc --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_generic.cuh @@ -0,0 +1,65 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace vllm { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_kernels.cu new file mode 100644 index 000000000..8afa088ee --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_kernels.cu @@ -0,0 +1,1010 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). +template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. +template +#else + int NUM_THREADS = 128> +#endif +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +#else + int NUM_THREADS = 128, int PARTITION_SIZE = 512> +#endif +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_utils.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_utils.cuh new file mode 100644 index 000000000..e19f0e367 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/attention_utils.cuh @@ -0,0 +1,57 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cuda_compat.h" +#include "attention_dtypes.h" + +#include +#include + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += VLLM_SHFL_XOR_SYNC(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/cuda_compat.h b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/cuda_compat.h new file mode 100644 index 000000000..bfa8c0071 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/cuda_compat.h @@ -0,0 +1,49 @@ +#pragma once + +#ifdef USE_ROCM + #include +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom.cu b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom.cu new file mode 100644 index 000000000..fae1b4fbf --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom.cu @@ -0,0 +1,78 @@ +#include +#include +#include + +// declare templates for front (cpp) and back (cuda) sides of function: +// template + +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); + +// template +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + // if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + + // std::to_string(in_a.numel()) + // + ", B.numel(): " + + // std::to_string(in_b.numel())); + + // out_c.resize_({N}); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, + const int N, cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx); + +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int64_t solidx = 0) { + auto M = in_a.size(0); + auto K = in_a.size(1); + + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), solidx); +} +// instantiate the CPP template for T=float: +// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor +// out_c); + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream); + +void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { + auto matA_sizes{in_a.sizes()}; + auto matB_sizes{in_b.sizes()}; + auto matO_sizes{out_c.sizes()}; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom_kernels.cu new file mode 100644 index 000000000..f7dba39bb --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/custom_kernels.cu @@ -0,0 +1,1309 @@ +#include +#include +#include +#include +#include "cuda_compat.h" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, + const int K) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; + __half2 acch2; + __half2 oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + __half2 Af2; + __half2 Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 half. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + + // See comment above concerning the if guard. + if (threadid * 8 < K) { + acc[i] = S.x + S.y; // accumulation on float + } + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], 16); + + if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { + oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +// define the kernel calling code: +// template +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +// instantiate the kernel template for T=float: +// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, +// const int M, const int K, cudaStream_t stream); + +const unsigned int TILE_WIDTH = 32; + +// Compute C = A * B +__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, + int numAColumns, int numBRows, + int numBColumns, int numCRows, + int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = + A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = + B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } +} + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream) { + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared<<>>( + in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, + numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +template +__global__ __launch_bounds__(512) void HGEMV_WFPerRow( + int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { + int num_row_per_block = CTA / nThreads_per_row; + int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; + int inc = (gridDim.x * num_row_per_block) * MT0; + + while (row_id < m) { + float2 sum2[MT0]; + +#pragma unroll + for (int i = 0; i < MT0; ++i) { + sum2[i] = {0.0, 0.0}; + } + + for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { + bool is_active = j < n; + if (is_active) { + float2 x2[MT1 >> 1]; +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + x2[offset >> 1] = {x[j + nThreads_per_row * offset], + x[j + nThreads_per_row * (offset + 1)]}; + } + float2 a2[MT0][MT1 >> 1]; +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + a2[i][offset >> 1] = { + A[(row_id + i) * n + j + nThreads_per_row * offset], + A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; + } + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < (MT1 >> 1); offset++) { + sum2[i] += a2[i][offset] * x2[offset]; + } + } + } + } + float sum[MT0]; +#pragma unroll + for (int i = 0; i < MT0; i++) { + sum[i] = sum2[i].x + sum2[i].y; + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; + offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } + } + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < MT0; i++) { + y[row_id + i] = sum[i]; + } + } + row_id += inc; + } +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx = 0) { + // m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx == 0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +///////////////////////////////////////////// + +#define DTYPE half + +__device__ __forceinline__ int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + // uint32_t commitColumn[YTILE]; + // for (uint32_t i = 0; i < YTILE; i++) { + // commitColumn[i] = 1; + //} + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); + while (n < Nrndp) { + #else + while (n < N) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + bigType bigB9[UNRL]; + bigType bigB10[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + /*wvSpltK_hf:*/ \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else { \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_bfloat16.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_bfloat16.cuh new file mode 100644 index 000000000..97a25baa1 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_bfloat16.cuh @@ -0,0 +1,463 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#include + +namespace vllm { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template <> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template <> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template <> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat1622float2(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hadd2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template <> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template <> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template <> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(a, b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(bf162bf162(a), b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template <> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template <> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template <> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. +inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst = __float22bfloat162_rn(src); +#endif +} + +inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#endif +} + +inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#endif +} + +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +// Zero-out a variable. +inline __device__ void zero(__nv_bfloat16& dst) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2. + dst = __ushort_as_bfloat16((unsigned short)0x0000U); +#endif +} + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float16.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float16.cuh new file mode 100644 index 000000000..3a1815f0e --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float16.cuh @@ -0,0 +1,504 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifdef USE_ROCM + #include +#endif + +#include + +namespace vllm { + +// FP16 vector types for Q, K, V. +template <> +struct Vec { + using Type = uint16_t; +}; +template <> +struct Vec { + using Type = uint32_t; +}; +template <> +struct Vec { + using Type = uint2; +}; +template <> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif +} + +inline __device__ float half_to_float(uint16_t h) { + float f; +#ifndef USE_ROCM + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); +#else + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); +#endif + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#else + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif +#else + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#ifndef USE_ROCM + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); +#endif + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } + +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a variable. +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float32.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float32.cuh new file mode 100644 index 000000000..7c6a686db --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_float32.cuh @@ -0,0 +1,251 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { return a * b; } + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { dst = src; } + +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + +// From float to float. +inline __device__ float to_float(float u) { return u; } + +inline __device__ float2 to_float(float2 u) { return u; } + +inline __device__ float4 to_float(float4 u) { return u; } + +inline __device__ Float4_ to_float(Float4_ u) { return u; } + +inline __device__ Float8_ to_float(Float8_ u) { return u; } + +// Zero-out a variable. +inline __device__ void zero(float& dst) { dst = 0.f; } + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_fp8.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_fp8.cuh new file mode 100644 index 000000000..e714e321b --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/fused_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/fused_kernels.cu new file mode 100644 index 000000000..4f3eea456 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/fused_kernels.cu @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +constexpr int WARP_SIZE = 64; + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c, + const int d) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; + const int row_addr_d = row_addr + d * blockDim.x; + // int row_addr_1 = row_addr + CUDA_NUM_THREADS; + // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; + // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + // float4 colB_elem4; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; + __half2 acch2; + __half2 oval; + + // rowA_elem4 = af4[row_addr + threadid]; + //__syncthreads(); + // rowA_elem4_1 = af4[row_addr_1 + threadid]; + // rowA_elem4_2 = af4[row_addr_2 + threadid]; + // rowA_elem4_3 = af4[row_addr_3 + threadid]; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { + rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); + rowA_elem4[2 * i + 1] = + load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); + // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; + //__syncthreads(); + } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + // __syncthreads(); + __half2 Af2; + __half2 Bf2; + float2 S; + // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<_Float16*>(out_c); + const int d = M / 2; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm_Silu_kernel<2> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 4) { + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 8) { + LLGemm_Silu_kernel<8> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 16) { + LLGemm_Silu_kernel<16> + <<>>(af4, bf4, c, d); + } else { + NUM_BLOCKS = M / 4; + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8.h b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8.h new file mode 100644 index 000000000..f9c80fcde --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8.h @@ -0,0 +1,137 @@ +#pragma once + +#ifdef __HIPCC__ + #include +#else + #include + #include + #include + #include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } +}; + +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8_impl.h b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8_impl.h new file mode 100644 index 000000000..90251c353 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ +#else + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl { + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.cpp b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.cpp new file mode 100644 index 000000000..468433ba1 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.cpp @@ -0,0 +1,26 @@ +#include "paged_attn_ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("paged_attention", &paged_attention, + "paged_attention(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? alibi_slopes," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()" + ); + m.def("LLMM1", &LLMM1, + "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " + "()"); + m.def("LLMM_Silu", &LLMM_Silu, + "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) " + "-> ()"); + m.def("wvSpltK", &wvSpltK, + "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.h b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.h new file mode 100644 index 000000000..18c72f937 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/paged_attn_ops.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); + +void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/quant_utils.cuh b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/quant_utils.cuh new file mode 100644 index 000000000..25cab5a91 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/csrc/quant_utils.cuh @@ -0,0 +1,710 @@ +#pragma once +#include "hip_float8.h" + +#include +#include +#include + +#include "attention_dtypes.h" + +namespace vllm { +#ifdef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; +} + +template +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +// float2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +// Float4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +// Float4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +// Float8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +// float2 -> bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +// Float4 -> bfloat162x2 +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +// Float8 -> bfloat162x4 +template <> +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains + + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP + + */ + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data / scale)}; + return f8.data; +} + +// halfx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; + #endif +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; +} + +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, float scale) { + hip_fp8 f8(a); + return f8.data; +} + +// floatx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; + #endif +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; +} + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/setup.py b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/setup.py new file mode 100644 index 000000000..ef5132b54 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/setup.py @@ -0,0 +1,241 @@ +import sys +import warnings +import os +import re +import ast +import shutil +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, + ROCM_HOME, + IS_HIP_EXTENSION, +) + + +this_dir = os.path.dirname(os.path.abspath(__file__)) +PACKAGE_NAME='pagedAttn' +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") + +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_ROCM = True + else: + IS_ROCM = False +else: + if BUILD_TARGET == "cuda": + IS_ROCM = False + elif BUILD_TARGET == "rocm": + IS_ROCM = True + +FORCE_CXX11_ABI=False + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return f'linux_{platform.uname().machine}' + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def check_if_rocm_home_none(global_option: str) -> None: + if ROCM_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but hipcc was not found." + ) + + +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "4" + return nvcc_extra_args + ["--threads", nvcc_threads] + + +def rename_cpp_to_cu(pth): + ret=[] + dst=pth.replace('csrc','build') + for entry in os.listdir(pth): + if entry.endswith(".cpp") or entry.endswith(".cu"): + newName=entry.replace(".cpp", ".cu") + shutil.copy(f'{pth}/{entry}', f'{dst}/{newName}') + ret.append(f'{dst}/{newName}') + return ret + + +def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1100"] + + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention" + + +cmdclass = {} +ext_modules = [] + +if IS_ROCM: + #use codegen get code dispatch + if not os.path.exists(f"{this_dir}/build"): + os.makedirs(f"{this_dir}/build") + + print(f"\n\ntorch.__version__ = {torch.__version__}\n\n") + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + cc_flag = [] + + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + + renamed_sources=rename_cpp_to_cu(f"{this_dir}/csrc") + + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": + [ + "-O3","-std=c++17", + "-mllvm", "-enable-post-misched=0", + "-DUSE_PROF_API=1", + "-D__HIP_PLATFORM_HCC__=1", + "-D__HIP_PLATFORM_AMD__=1", + "-U__HIP_NO_HALF_CONVERSIONS__", + "-U__HIP_NO_HALF_OPERATORS__" + ] + + generator_flag + + cc_flag + , + } + + include_dirs = [ + f"{this_dir}/csrc", + ] + + ext_modules.append( + CUDAExtension( + name=PACKAGE_NAME, + sources=renamed_sources, + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + ) + ) +else: + raise NotImplementedError("Only ROCM is supported") + +class NinjaBuildExtension(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + import psutil + + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, os.cpu_count() // 2) + + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB + max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4 + + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + os.environ["MAX_JOBS"] = str(max_jobs) + + super().__init__(*args, **kwargs) + + +setup( + name=PACKAGE_NAME, + version= "0.1.0", + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + ) + ), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"build_ext": NinjaBuildExtension}, + python_requires=">=3.8", + install_requires=[ + "torch", + "einops", + ], + setup_requires=[ + "packaging", + "psutil", + "ninja", + ], +) +if os.path.exists(f"{this_dir}/build"): + shutil.rmtree(f"{this_dir}/build") + shutil.rmtree(f"./.eggs") + shutil.rmtree(f"./build") + shutil.rmtree(f"./pagedAttn.egg-info") \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/tests/test_paged_attn.py b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/tests/test_paged_attn.py new file mode 100644 index 000000000..cad93d734 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/paged_attn/tests/test_paged_attn.py @@ -0,0 +1,413 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch +import pagedAttn as pa +import numpy as np +from typing import (List, Optional, Tuple, Union) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# This will change depending on the compute capability. +# - 512 as a buffer +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = (60 * 1024) // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +# FlashAttention forward only supports head dimension at most 128 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +# HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [64, 128] + +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto"] +# KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + +def seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables_lst = block_tables.cpu().tolist() + seq_lens_lst = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables_lst[i] + seq_len = int(seq_lens_lst[i]) + + keys_lst: List[torch.Tensor] = [] + values_lst: List[torch.Tensor] = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, :, :, block_offset] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + + seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + # elif cache_dtype == 'fp8': + # _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + # elif cache_dtype == 'fp8': + # _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + +def test_paged_attention( + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): + pytest.skip() + print("running test:", (num_seqs, num_heads, head_size, block_size, dtype, kv_cache_dtype)) + seed_everything(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: List[List[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + k_scale = v_scale = 1.0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + if version in ("rocm"): + PARTITION_SIZE = 1024 if version == "v2" else 512 + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + + pa.paged_attention( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + # opcheck(torch.ops._rocm_C.paged_attention, + # (output, exp_sums, max_logits, tmp_output, query, + # key_cache, value_cache, num_kv_heads, scale, block_tables, + # seq_lens, block_size, max_seq_len, alibi_slopes, + # kv_cache_dtype, k_scale, v_scale), + # cond=(head_size == HEAD_SIZES[0])) + + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + # if kv_cache_dtype == "fp8": + # # Convert cache data back to dtype. + # x = 16 // torch.tensor([], dtype=dtype).element_size() + # key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + # block_size, x) + # dequantized_key_cache = torch.empty(size=key_cache_shape, + # dtype=dtype, + # device=device) + # ops.convert_fp8(dequantized_key_cache, key_cache) + # key_cache = dequantized_key_cache + + # value_cache_shape = value_cache.shape + # dequantized_value_cache = torch.empty(size=value_cache_shape, + # dtype=dtype, + # device=device) + # ops.convert_fp8(dequantized_value_cache, value_cache) + # value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = 1e-3 + rtol = 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + print(" test passed!") + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs: List[torch.Tensor] = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + + return torch.cat(ref_outputs, dim=0) + + + # version: str, + # num_seqs: int, + # num_heads: Tuple[int, int], + # head_size: int, + # use_alibi: bool, + # block_size: int, + # dtype: torch.dtype, + # kv_cache_dtype: str, + # seed: int, + # device: str, +if __name__ == "__main__": + version = "rocm" + for num_seqs in NUM_GEN_SEQS: + for num_heads in NUM_HEADS: + for head_size in HEAD_SIZES: + for use_alibi in USE_ALIBI: + for block_size in BLOCK_SIZES: + for dtype in DTYPES: + for kv_cache_dtype in KV_CACHE_DTYPE: + for device in CUDA_DEVICES: + test_paged_attention(version=version, + num_seqs=num_seqs, + num_heads=num_heads, + head_size=head_size, + use_alibi=use_alibi, + block_size=block_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + seed=0, + device=device) \ No newline at end of file