diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp index 1fde5ac684..1870765276 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp @@ -66,8 +66,8 @@ at::Tensor rope_qkv_varseq_prefill_meta( bool /* k_norm */, bool /* update_kv */, std::optional /* amax_qkv */, - std::optional /* kv_quant_scale_precomputed */ -) { + std::optional /* kv_quant_scale_precomputed */, + bool /* symmetric_quant */) { return at::empty_like(XQ); } @@ -95,8 +95,8 @@ at::Tensor rope_qkv_decoding_meta( std::optional /* qparam_v */, bool /* k_norm */, bool /* update_kv */, - std::optional /* amax_qkv */ -) { + std::optional /* amax_qkv */, + bool /* symmetric_quant */) { return at::empty_like(XQ); } @@ -233,7 +233,8 @@ std::tuple dequantize_fp8_cache_meta( std::optional qparam_k, std::optional /* qparam_v */, std::optional /* block_tables */, - int64_t /* page_size */) { + int64_t /* page_size */, + std::optional /*symmetric*/) { const at::SymInt B_KV = cache_K.sym_size(0); const at::SymInt MAX_T = cache_K.sym_size(1); const at::SymInt N_KVH = cache_K.sym_size(2); diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index a203a308a7..dc50366159 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -87,7 +87,10 @@ enum class PositionEmbeddingMode { ROPE = 0, XPOS = 1, NOPE = 2 }; enum class KVQuantRecipe { perTokenScaling = 0, perHeadScaling = 1 }; enum class QKV { Q, K, V }; -template +template < + typename T, + KVQuantRecipe recipe = KVQuantRecipe::perTokenScaling, + bool symmetric = false> DEVICE_INLINE void quantize_fp8_kv( fx4 dst, T* dst_row_q, @@ -632,6 +635,7 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false) { NUM_GROUPS, \ DTYPE, \ EMB_MODE, \ + SYMMETRIC_QUANT, \ VARSEQ_BATCH, \ VARSEQ_SEQPOS, \ THETA, \ @@ -650,46 +654,87 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false) { hi_freq_factor, \ write_k_back, \ k_norm) \ - FBGEMM_LAUNCH_KERNEL( \ - (rope_xpos_qkv_varseq_prefill_kernel_quantized< \ - EMB_MODE, \ - DTYPE, \ - NUM_GROUPS>), \ - blocks, \ - threads, \ - 0, \ - at::cuda::getCurrentCUDAStream(), \ - PTA_B(XQ, at::BFloat16, 3, 32), \ - PTA_B(XK, at::BFloat16, 3, 32), \ - PTA_B(XV, at::BFloat16, 3, 32), \ - PTA_B(cache_K, uint8_t, 4, 64), \ - PTA_B(cache_V, uint8_t, 4, 64), \ - qparam_k_ptr, \ - qparam_v_ptr, \ - PTA_B(XQ_O, at::BFloat16, 3, 32), \ - VARSEQ_BATCH, \ - VARSEQ_SEQPOS, \ - THETA, \ - GAMMA, \ - SCALE_BASE, \ - EXPO_OFFSET, \ - block_tables, \ - page_size, \ - block_tables_b_stride, \ - varseq_cache_seqpos, \ - actual_batch_size, \ - rope_scaling, \ - old_context_len, \ - scaling_factor, \ - lo_freq_factor, \ - hi_freq_factor, \ - write_k_back, \ - k_norm); + if (SYMMETRIC_QUANT) { \ + FBGEMM_LAUNCH_KERNEL( \ + (rope_xpos_qkv_varseq_prefill_kernel_quantized< \ + EMB_MODE, \ + DTYPE, \ + NUM_GROUPS, \ + true>), \ + blocks, \ + threads, \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + PTA_B(XQ, at::BFloat16, 3, 32), \ + PTA_B(XK, at::BFloat16, 3, 32), \ + PTA_B(XV, at::BFloat16, 3, 32), \ + PTA_B(cache_K, uint8_t, 4, 64), \ + PTA_B(cache_V, uint8_t, 4, 64), \ + qparam_k_ptr, \ + qparam_v_ptr, \ + PTA_B(XQ_O, at::BFloat16, 3, 32), \ + VARSEQ_BATCH, \ + VARSEQ_SEQPOS, \ + THETA, \ + GAMMA, \ + SCALE_BASE, \ + EXPO_OFFSET, \ + block_tables, \ + 
page_size, \ + block_tables_b_stride, \ + varseq_cache_seqpos, \ + actual_batch_size, \ + rope_scaling, \ + old_context_len, \ + scaling_factor, \ + lo_freq_factor, \ + hi_freq_factor, \ + write_k_back, \ + k_norm); \ + } else { \ + FBGEMM_LAUNCH_KERNEL( \ + (rope_xpos_qkv_varseq_prefill_kernel_quantized< \ + EMB_MODE, \ + DTYPE, \ + NUM_GROUPS, \ + false>), \ + blocks, \ + threads, \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + PTA_B(XQ, at::BFloat16, 3, 32), \ + PTA_B(XK, at::BFloat16, 3, 32), \ + PTA_B(XV, at::BFloat16, 3, 32), \ + PTA_B(cache_K, uint8_t, 4, 64), \ + PTA_B(cache_V, uint8_t, 4, 64), \ + qparam_k_ptr, \ + qparam_v_ptr, \ + PTA_B(XQ_O, at::BFloat16, 3, 32), \ + VARSEQ_BATCH, \ + VARSEQ_SEQPOS, \ + THETA, \ + GAMMA, \ + SCALE_BASE, \ + EXPO_OFFSET, \ + block_tables, \ + page_size, \ + block_tables_b_stride, \ + varseq_cache_seqpos, \ + actual_batch_size, \ + rope_scaling, \ + old_context_len, \ + scaling_factor, \ + lo_freq_factor, \ + hi_freq_factor, \ + write_k_back, \ + k_norm); \ + } template < PositionEmbeddingMode EmbMode, CacheLogicalDtype kCacheDtype, - int KVQuantNumGroups = 1> + int KVQuantNumGroups = 1, + bool symmetric = false> __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized( pta::PackedTensorAccessor32 XQ, // [B_T][N_H][D_H] @@ -699,10 +744,12 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized( XV, // [B_T][N_KVH][D_H] pta::PackedTensorAccessor64 cache_K, // [B][MAX_T][N_KVH][D_H] or - // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention + // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged + // attention pta::PackedTensorAccessor64 cache_V, // [B][MAX_T][N_KVH][D_H] or - // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention + // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged + // attention int32_t* qparam_k_ptr, int32_t* qparam_v_ptr, pta::PackedTensorAccessor32 @@ -873,7 +920,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized( qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); } } - quantize_fp8_kv(dst, dst_row_q, qparam_row); + quantize_fp8_kv( + dst, dst_row_q, qparam_row); } else if (kCacheDtype == CacheLogicalDtype::INT4) { CUDA_KERNEL_ASSERT(D_H_q - D_H / 2 == 4 * KVQuantNumGroups); quantize_int4_kv(dst, dst_row_q); @@ -1211,6 +1259,7 @@ at::Tensor nope_qkv_varseq_prefill( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::NOPE, + false, // symmetric_quant is False for FP8 NoPE varseq_batch_, varseq_seqpos_, 0, @@ -1240,6 +1289,7 @@ at::Tensor nope_qkv_varseq_prefill( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::NOPE, + false, // symmetric_quant is False for INT4 NoPE varseq_batch_, varseq_seqpos_, 0, @@ -1406,6 +1456,7 @@ at::Tensor nope_qkv_decoding( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::NOPE, + false, // symmetric_quant is False for FP8 NoPE batch_, seqpos_, 0, @@ -1435,6 +1486,7 @@ at::Tensor nope_qkv_decoding( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::NOPE, + false, // symmetric_quant is False for INT4 NoPE batch_, seqpos_, 0, @@ -1485,7 +1537,8 @@ at::Tensor rope_qkv_varseq_prefill( bool k_norm = false, bool update_kv = true, std::optional amax_qkv = std::nullopt, - std::optional kv_quant_scale_precomputed = std::nullopt) { + std::optional kv_quant_scale_precomputed = std::nullopt, + bool symmetric_quant = false) { auto B_T = XQ.size(0); auto N_H = XQ.size(1); auto N_KVH = 0; @@ -1630,6 +1683,7 @@ at::Tensor rope_qkv_varseq_prefill( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::ROPE, + symmetric_quant, // for FP8 RoPE 
varseq_batch_, varseq_seqpos_, theta, @@ -1659,6 +1713,7 @@ at::Tensor rope_qkv_varseq_prefill( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::ROPE, + false, // symmetric_quant is False for INT4 RoPE varseq_batch_, varseq_seqpos_, theta, @@ -1791,6 +1846,7 @@ at::Tensor xpos_qkv_varseq_prefill( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::XPOS, + false, // symmetric_quant is False for FP8 XPOS varseq_batch_, varseq_seqpos_, theta, @@ -1820,6 +1876,7 @@ at::Tensor xpos_qkv_varseq_prefill( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::XPOS, + false, // symmetric_quant is False for INT4 XPOS varseq_batch_, varseq_seqpos_, theta, @@ -1869,7 +1926,8 @@ at::Tensor rope_qkv_decoding( std::optional qparam_v = std::nullopt, bool k_norm = false, bool update_kv = true, - std::optional amax_qkv = std::nullopt) { + std::optional amax_qkv = std::nullopt, + bool symmetric_quant = false) { auto B = XQ.size(0); auto N_H = XQ.size(1); auto N_KVH = 0; @@ -2002,6 +2060,7 @@ at::Tensor rope_qkv_decoding( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::ROPE, + symmetric_quant, // symmetric_quant is True for FP8 RoPE nullptr, seqpos_, theta, @@ -2032,6 +2091,7 @@ at::Tensor rope_qkv_decoding( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::ROPE, + false, // symmetric_quant is False for INT4 RoPE nullptr, seqpos_, theta, @@ -2159,6 +2219,7 @@ at::Tensor xpos_qkv_decoding( 1, CacheLogicalDtype::FP8, PositionEmbeddingMode::XPOS, + false, // symmetric_quant is False for FP8 XPOS nullptr, seqpos_, theta, @@ -2187,6 +2248,7 @@ at::Tensor xpos_qkv_decoding( num_groups_, CacheLogicalDtype::INT4, PositionEmbeddingMode::XPOS, + false, // symmetric_quant is False for INT4 XPOS nullptr, seqpos_, theta, @@ -2233,7 +2295,7 @@ DEVICE_INLINE uint32_t packComponents(uint32_t x_bits[4]) { return packed; } -template +template DEVICE_INLINE void quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) { if (do_norm) { @@ -2253,7 +2315,12 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) { warp_min = -warpReduceMax(-thread_min, mask); warp_max = warpReduceMax(thread_max, mask); - auto bounded_max = (warp_max - warp_min) / 2; + float bounded_max; + if (symmetric) { + bounded_max = warp_max; + } else { + bounded_max = (warp_max - warp_min) / 2; + } // max FP16 value is 65504.0f. // Divide by 2 to avoid overflow during // e4m3fn (NV) to e4m3fnuz (AMD) conversion @@ -2262,7 +2329,7 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) { bounded_max = std::min(bounded_max, scale_ub); scale = static_cast( std::max(bounded_max / FP8_E4M3_MAX::value, min_scaling_factor)); - shift = warp_min + FP8_E4M3_MAX::value * scale; + shift = symmetric ? 
0 : warp_min + FP8_E4M3_MAX::value * scale; } else { // Support of per-head scaling is limited to reading a // pre-calculated scale from qparam tensor and using it for scaling the @@ -2291,11 +2358,16 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) { param_store = reinterpret_cast<__half2*>(&dst_row_q[0]); } CUDA_KERNEL_ASSERT(uintptr_t(param_store) % 4 == 0); - *param_store = __floats2half2_rn(scale, shift); + if (symmetric) { + float* param_store_fp32 = reinterpret_cast(param_store); + *param_store_fp32 = scale; + } else { + *param_store = __floats2half2_rn(scale, shift); + } } } #else -template +template DEVICE_INLINE void quantize_fp8_kv(fx4 dst, T* dst_row_, __half2* qparam, bool do_norm) {} std::vector quantize_fp8_per_tensor( diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h index 30363404e7..de9310c2f0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h @@ -77,7 +77,8 @@ at::Tensor rope_qkv_varseq_prefill( bool k_norm, bool update_kv, std::optional amax_qkv, - std::optional kv_quant_scale_precomputed); + std::optional kv_quant_scale_precomputed, + bool symmetric_quant); at::Tensor rope_qkv_decoding( at::Tensor XQ, @@ -103,7 +104,8 @@ at::Tensor rope_qkv_decoding( std::optional qparam_v, bool k_norm, bool update_kv, - std::optional amax_qkv); + std::optional amax_qkv, + bool symmetric_quant); at::Tensor xpos_qkv_varseq_prefill( at::Tensor XQ, @@ -172,7 +174,8 @@ std::tuple dequantize_fp8_cache( std::optional qparam_k, std::optional qparam_v, std::optional block_tables, - int64_t page_size); + int64_t page_size, + std::optional symmetric); at::Tensor quantize_qkv_per_head( at::Tensor amax, diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_defs.cpp index 885714a199..1bd1280723 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_defs.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_defs.cpp @@ -17,9 +17,9 @@ namespace fbgemm_gpu { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!)? XK, Tensor? XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192" - ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False,bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor"); + ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False,bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None, bool symmetric_quant=False) -> Tensor"); m.def("rope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( - DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? 
cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor"); + DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None, bool symmetric_quant=False) -> Tensor"); m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING( DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor"); m.def("nope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING( @@ -32,7 +32,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)"); m.def( "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING( - DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)"); + DEFAULT_PAGE_SIZE) ", bool? symmetric=False) -> (Tensor, Tensor)"); m.def( "quantize_qkv_per_head(Tensor amax, Tensor XQKV, Tensor varseq_seqpos, Tensor? varseq_batch, Tensor? is_precalculated_qparam, Tensor cache_K, Tensor cache_V, Tensor XQ_O, int B, Tensor? qparam_k=None, Tensor? 
qparam_v=None) -> Tensor"); m.def( diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu index 000240393a..542dc50470 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu @@ -170,7 +170,7 @@ std::tuple dequantize_int4_cache( (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) -template +template __global__ void dequantize_fp8_cache_kernel( // This code currently represents FP8 version not int4 at::PackedTensorAccessor64 @@ -232,7 +232,15 @@ __global__ void dequantize_fp8_cache_kernel( tidx -= 32; } uint32_t q = *reinterpret_cast(&row[tidx * 4 + offset_bytes]); - kv_dq = dequantize_packed_fp8(q, *qparam_src); + + if (Symmetric) { + // No shift, FP32 scale + float* qparam_src_fp32 = reinterpret_cast(qparam_src); + kv_dq = dequantize_packed_fp8_symmetric(q, *qparam_src_fp32); + } else { + kv_dq = dequantize_packed_fp8(q, *qparam_src); + } + // now, write our outputs // each thread writes 4 elements of type bf16 *reinterpret_cast(&row_dq[4 * tidx]) = @@ -275,7 +283,7 @@ __global__ void dequantize_fp8_cache_kernel_paged( } #else -template +template __global__ void dequantize_fp8_cache_kernel( // This code currently represents FP8 version not int4 at::PackedTensorAccessor64 @@ -294,9 +302,9 @@ __global__ void dequantize_fp8_cache_kernel( auto D_H = cache_K_dq.size(3); auto D_H_q = cache_K.size(3); // TODO: support D_H < 128 for small model used in testing. - CUDA_KERNEL_ASSERT(D_H == 128); + // CUDA_KERNEL_ASSERT(D_H == 128); const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4; - CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes); + // CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes); auto b = blockIdx.x; // only need to dequantize this far. 
@@ -319,24 +327,30 @@ __global__ void dequantize_fp8_cache_kernel( row_k_dq = &cache_K_dq[b][t][h][0]; row_v_dq = &cache_V_dq[b][t][h][0]; // Calculate kv_dq for this row - { - __half2* qparam_k_src; - __half2* qparam_v_src; - if (ExternalQParam) { - size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h; - qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); - qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); - } else { - qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]); - qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]); - } - uint64_t kq = - *reinterpret_cast(&row_k[threadIdx.x * 4 + offset_bytes]); - uint64_t vq = - *reinterpret_cast(&row_v[threadIdx.x * 4 + offset_bytes]); - - packed = kq | (vq << 32); - + __half2* qparam_k_src; + __half2* qparam_v_src; + if (ExternalQParam) { + size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h; + qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } else { + qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]); + qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]); + } + uint64_t kq = + *reinterpret_cast(&row_k[threadIdx.x * 4 + offset_bytes]); + uint64_t vq = + *reinterpret_cast(&row_v[threadIdx.x * 4 + offset_bytes]); + + packed = kq | (vq << 32); + + if (symmetric) { + // No shift, FP32 scale + float* qparam_k_src_fp32 = reinterpret_cast(qparam_k_src); + float* qparam_v_src_fp32 = reinterpret_cast(qparam_v_src); + kv_dq = dequantize_packed_fp8_symmetric( + packed, *qparam_k_src_fp32, *qparam_v_src_fp32); + } else { kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src); } // now, write our outputs @@ -482,7 +496,8 @@ std::tuple dequantize_fp8_cache( std::optional qparam_k, std::optional qparam_v, std::optional block_tables, - int64_t page_size) { + int64_t page_size, + std::optional symmetric) { TORCH_CHECK(cache_K.is_cuda()); TORCH_CHECK(cache_V.is_cuda()); TORCH_CHECK(kv_seqlen.is_cuda()); @@ -538,8 +553,9 @@ std::tuple dequantize_fp8_cache( constexpr int32_t kMaxBlocks = 512; dim3 blocks(B, std::max(1, kMaxBlocks / B)); dim3 threads(kThreadsPerWarp, kWarpsPerBlock); -#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM) \ - const auto deq_fn = dequantize_fp8_cache_kernel; \ +#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM, SYMMETRIC_QUANT) \ + const auto deq_fn = \ + dequantize_fp8_cache_kernel; \ deq_fn<<>>( \ cache_K.packed_accessor64(), \ cache_V.packed_accessor64(), \ @@ -549,11 +565,21 @@ std::tuple dequantize_fp8_cache( qparam_k_ptr, \ qparam_v_ptr); \ C10_CUDA_KERNEL_LAUNCH_CHECK() + + bool use_symmetric_quantization = symmetric && symmetric.value(); if (block_tables_ptr == nullptr) { if (qparam_k_ptr) { - CALL_DEQUANTIZE_FP8_CACHE(true); + if (use_symmetric_quantization) { + CALL_DEQUANTIZE_FP8_CACHE(true, true); + } else { + CALL_DEQUANTIZE_FP8_CACHE(true, false); + } } else { - CALL_DEQUANTIZE_FP8_CACHE(false); + if (use_symmetric_quantization) { + CALL_DEQUANTIZE_FP8_CACHE(false, true); + } else { + CALL_DEQUANTIZE_FP8_CACHE(false, false); + } } #undef CALL_DEQUANTIZE_FP8_CACHE } else { diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py index d72e0f6e29..da81562852 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py @@ -357,6 +357,146 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None: cache_v[:, :T], cache_v_bf16[:, :T], 
atol=1.0e-2, rtol=5.0e-2 ) + @settings(deadline=None) + @given( + MAX_T=st.sampled_from([8000, 16384]), + N_KVH_L=st.sampled_from([1, 2]), + ) + @unittest.skipIf( + not torch.cuda.is_available() + or ( + torch.version.cuda + and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9 + ) + or (torch.version.hip and torch.version.hip < "6.2") + or not HAS_XFORMERS, + "Skip when H100 is not available or MI300 is not available", + ) + def test_symmetric_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None: + N_H_L = 2 + T = 2 + B = 2 + D_H = 128 + + xq = ( + torch.cat( + [ + torch.randn(N_H_L, D_H, dtype=torch.bfloat16, device=self.device) + * (i) + for i in range(B * T) + ] + ) + ).view(B * T, N_H_L, D_H) + scale_step = 0.01 / B / T + shift_step = 5 * scale_step + xk_rows = [ + scale_step + * (i + 1) + * torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device) + + i * shift_step + for i in range(B * T) + ] + xv_rows = [ + scale_step + * (i + 1) + * torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device) + + i * shift_step + for i in range(B * T) + ] + + xk = (torch.cat(xk_rows)).view(B * T, N_KVH_L, D_H) + + xv = (torch.cat(xv_rows)).view(B * T, N_KVH_L, D_H) + varseq_seqpos = torch.cat( + [ + torch.as_tensor(list(range(T)), dtype=torch.int, device=self.device) + for b in range(B) + ] + ) + varseq_batch = torch.cat( + [ + torch.as_tensor( + [b for _ in range(T)], dtype=torch.int, device=self.device + ) + for b in range(B) + ] + ) + attn_bias = ( + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[T for _ in range(B)], + kv_padding=MAX_T, + kv_seqlen=[T for _ in range(B)], + ) + ) + attn_bias.k_seqinfo.to(self.device) + assert attn_bias.k_seqinfo.seqlen.shape == (B,) + assert attn_bias.k_seqinfo.seqlen.tolist() == [T for _ in range(B)] + + theta = 10000.0 + cache_k_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device + ) + cache_v_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device + ) + + xq_out_bf16 = torch.compile( + torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend + )( + xq, + xk, + xv, + cache_k_bf16, + cache_v_bf16, + varseq_batch, + varseq_seqpos, + theta, + ) + qparam_offset = 4 + + cache_k_fp8 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset), + dtype=torch.uint8, + device=self.device, + ) + cache_v_fp8 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset), + dtype=torch.uint8, + device=self.device, + ) + xq_out = torch.compile( + torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend + )( + xq, + xk, + xv, + cache_k_fp8, + cache_v_fp8, + varseq_batch, + varseq_seqpos, + theta, + cache_logical_dtype_int=LogicalDtype.fp8.value, + symmetric_quant=True, + ) + torch.testing.assert_close(xq_out_bf16, xq_out) + + dequantized_cache = torch.compile( + torch.ops.fbgemm.dequantize_fp8_cache, backend=self.compile_backend + )( + cache_k_fp8, + cache_v_fp8, + attn_bias.k_seqinfo.seqlen, + symmetric=True, + ) + cache_k, cache_v = dequantized_cache + + torch.testing.assert_close( + cache_k[:, :T], cache_k_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2 + ) + torch.testing.assert_close( + cache_v[:, :T], cache_v_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2 + ) + @settings(deadline=None) @given( MAX_T=st.sampled_from([8000, 16384]), diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh index e2869e3062..32c03f01d0 100644 
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
@@ -415,6 +415,19 @@ DEVICE_INLINE bfx8 dequantize_packed_fp8_symmetric(
   result.vals[3] = __floats2bfloat162_rn(r3.x, r3.y);
   return result;
 }
+DEVICE_INLINE bfx4 dequantize_packed_fp8_symmetric(
+    uint32_t xs, // x0 x1 x2 x3
+    float scale) {
+  __nv_fp8_e4m3* fp8_vs = reinterpret_cast<__nv_fp8_e4m3*>(&xs); // 4 element
+
+  auto r0 = make_float2(float(fp8_vs[0]) * scale, float(fp8_vs[1]) * scale);
+  auto r1 = make_float2(float(fp8_vs[2]) * scale, float(fp8_vs[3]) * scale);
+
+  bfx4 result;
+  result.vals[0] = __floats2bfloat162_rn(r0.x, r0.y);
+  result.vals[1] = __floats2bfloat162_rn(r1.x, r1.y);
+  return result;
+}
 DEVICE_INLINE bfx4 dequantize_packed_fp8(uint32_t vs, __half2 shift_scale_0) {
   uint32_t v = vs;
   __nv_fp8_e4m3* fp8_k = reinterpret_cast<__nv_fp8_e4m3*>(&v); // 4 element
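For context on the math this patch implements: with symmetric_quant the per-row FP8 qparam shrinks from an fp16 (scale, shift) pair to a single fp32 scale with an implicit zero shift. In quantize_fp8_kv<..., true> that scale comes from the warp-reduced row maximum divided by FP8_E4M3_MAX (floored by min_scaling_factor), and dequantize_packed_fp8_symmetric simply multiplies the fp8 values back by it. Below is a rough PyTorch sketch of that round trip, not the CUDA kernel: FP8_MAX, the abs-max reduction, the 1e-9 scale floor, and returning the scale as a separate tensor (instead of packing it into the row's 4-byte qparam header) are illustrative assumptions.

import torch

# Max representable e4m3fn value (448.0 on NVIDIA); the AMD e4m3fnuz variant tops out at 240.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_row_symmetric(x: torch.Tensor):
    # Symmetric per-row quantization: one fp32 scale, no shift.
    amax = x.abs().amax().float()
    # 1e-9 stands in for the kernel's min_scaling_factor (an assumption here).
    scale = torch.clamp(amax / FP8_MAX, min=1e-9)
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

def dequantize_row_symmetric(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Mirrors dequantize_packed_fp8_symmetric: multiply by the scale, nothing to shift.
    return (q.to(torch.float32) * scale).to(torch.bfloat16)

if __name__ == "__main__":
    row = torch.randn(128, dtype=torch.bfloat16)
    q, scale = quantize_row_symmetric(row)
    torch.testing.assert_close(
        dequantize_row_symmetric(q, scale), row, atol=1.0e-1, rtol=1.0e-1
    )

In the layout exercised by test_symmetric_fp8_kv_cache, that fp32 scale occupies the 4-byte qparam prefix of each quantized row (qparam_offset = 4), replacing the __half2 (scale, shift) pair written by the asymmetric path; this is why dequantize_fp8_cache gains the symmetric flag, so the dequantization kernels know to reinterpret those bytes as a float rather than a __half2 before scaling.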