11 changes: 6 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -66,8 +66,8 @@ at::Tensor rope_qkv_varseq_prefill_meta(
bool /* k_norm */,
bool /* update_kv */,
std::optional<at::Tensor> /* amax_qkv */,
std::optional<at::Tensor> /* kv_quant_scale_precomputed */
) {
std::optional<at::Tensor> /* kv_quant_scale_precomputed */,
bool /* symmetric_quant */) {
return at::empty_like(XQ);
}

@@ -95,8 +95,8 @@ at::Tensor rope_qkv_decoding_meta(
std::optional<at::Tensor> /* qparam_v */,
bool /* k_norm */,
bool /* update_kv */,
std::optional<at::Tensor> /* amax_qkv */
) {
std::optional<at::Tensor> /* amax_qkv */,
bool /* symmetric_quant */) {
return at::empty_like(XQ);
}

@@ -233,7 +233,8 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache_meta(
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> /* qparam_v */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */) {
int64_t /* page_size */,
std::optional<bool> /*symmetric*/) {
const at::SymInt B_KV = cache_K.sym_size(0);
const at::SymInt MAX_T = cache_K.sym_size(1);
const at::SymInt N_KVH = cache_K.sym_size(2);
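These meta ("fake tensor") overloads only do shape inference, but their argument lists must stay in sync with the real operators so the registered schema still matches; the new trailing flags are accepted and ignored. A minimal, hypothetical sketch of the pattern (the op name and arguments below are illustrative, not this file's):

```cpp
#include <ATen/ATen.h>
#include <optional>

// Hypothetical meta implementation: it must accept exactly the same arguments
// as the CUDA operator it shadows, but only shapes and dtypes matter here.
at::Tensor my_rope_like_op_meta(
    at::Tensor XQ,
    std::optional<at::Tensor> /* amax_qkv */,
    bool /* symmetric_quant */) { // newly appended flag, unused for shape inference
  return at::empty_like(XQ);      // output mirrors XQ's shape and dtype
}
```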
166 changes: 119 additions & 47 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -87,7 +87,10 @@ enum class PositionEmbeddingMode { ROPE = 0, XPOS = 1, NOPE = 2 };
enum class KVQuantRecipe { perTokenScaling = 0, perHeadScaling = 1 };
enum class QKV { Q, K, V };

template <typename T, KVQuantRecipe recipe = KVQuantRecipe::perTokenScaling>
template <
typename T,
KVQuantRecipe recipe = KVQuantRecipe::perTokenScaling,
bool symmetric = false>
DEVICE_INLINE void quantize_fp8_kv(
fx4 dst,
T* dst_row_q,
@@ -632,6 +635,7 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false) {
NUM_GROUPS, \
DTYPE, \
EMB_MODE, \
SYMMETRIC_QUANT, \
VARSEQ_BATCH, \
VARSEQ_SEQPOS, \
THETA, \
@@ -650,46 +654,87 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false)
hi_freq_factor, \
write_k_back, \
k_norm) \
FBGEMM_LAUNCH_KERNEL( \
(rope_xpos_qkv_varseq_prefill_kernel_quantized< \
EMB_MODE, \
DTYPE, \
NUM_GROUPS>), \
blocks, \
threads, \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(XQ, at::BFloat16, 3, 32), \
PTA_B(XK, at::BFloat16, 3, 32), \
PTA_B(XV, at::BFloat16, 3, 32), \
PTA_B(cache_K, uint8_t, 4, 64), \
PTA_B(cache_V, uint8_t, 4, 64), \
qparam_k_ptr, \
qparam_v_ptr, \
PTA_B(XQ_O, at::BFloat16, 3, 32), \
VARSEQ_BATCH, \
VARSEQ_SEQPOS, \
THETA, \
GAMMA, \
SCALE_BASE, \
EXPO_OFFSET, \
block_tables, \
page_size, \
block_tables_b_stride, \
varseq_cache_seqpos, \
actual_batch_size, \
rope_scaling, \
old_context_len, \
scaling_factor, \
lo_freq_factor, \
hi_freq_factor, \
write_k_back, \
k_norm);
if (SYMMETRIC_QUANT) { \
FBGEMM_LAUNCH_KERNEL( \
(rope_xpos_qkv_varseq_prefill_kernel_quantized< \
EMB_MODE, \
DTYPE, \
NUM_GROUPS, \
true>), \
blocks, \
threads, \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(XQ, at::BFloat16, 3, 32), \
PTA_B(XK, at::BFloat16, 3, 32), \
PTA_B(XV, at::BFloat16, 3, 32), \
PTA_B(cache_K, uint8_t, 4, 64), \
PTA_B(cache_V, uint8_t, 4, 64), \
qparam_k_ptr, \
qparam_v_ptr, \
PTA_B(XQ_O, at::BFloat16, 3, 32), \
VARSEQ_BATCH, \
VARSEQ_SEQPOS, \
THETA, \
GAMMA, \
SCALE_BASE, \
EXPO_OFFSET, \
block_tables, \
page_size, \
block_tables_b_stride, \
varseq_cache_seqpos, \
actual_batch_size, \
rope_scaling, \
old_context_len, \
scaling_factor, \
lo_freq_factor, \
hi_freq_factor, \
write_k_back, \
k_norm); \
} else { \
FBGEMM_LAUNCH_KERNEL( \
(rope_xpos_qkv_varseq_prefill_kernel_quantized< \
EMB_MODE, \
DTYPE, \
NUM_GROUPS, \
false>), \
blocks, \
threads, \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(XQ, at::BFloat16, 3, 32), \
PTA_B(XK, at::BFloat16, 3, 32), \
PTA_B(XV, at::BFloat16, 3, 32), \
PTA_B(cache_K, uint8_t, 4, 64), \
PTA_B(cache_V, uint8_t, 4, 64), \
qparam_k_ptr, \
qparam_v_ptr, \
PTA_B(XQ_O, at::BFloat16, 3, 32), \
VARSEQ_BATCH, \
VARSEQ_SEQPOS, \
THETA, \
GAMMA, \
SCALE_BASE, \
EXPO_OFFSET, \
block_tables, \
page_size, \
block_tables_b_stride, \
varseq_cache_seqpos, \
actual_batch_size, \
rope_scaling, \
old_context_len, \
scaling_factor, \
lo_freq_factor, \
hi_freq_factor, \
write_k_back, \
k_norm); \
}
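The launch macro now takes a SYMMETRIC_QUANT argument and branches on it so that `symmetric` reaches the kernel as a compile-time template parameter rather than a runtime bool. A minimal C++ sketch of this runtime-to-compile-time dispatch, with illustrative names that are not the kernel's:

```cpp
// Stand-in for the templated kernel: `symmetric` is a template parameter, so
// the branch below is resolved at compile time and the unused path emits no code.
template <bool symmetric>
float bounded_max_for_row(float row_min, float row_max) {
  if constexpr (symmetric) {
    return row_max;                     // symmetric: scale derives from the max alone
  } else {
    return (row_max - row_min) / 2.0f;  // asymmetric: scale derives from the half-range
  }
}

// Mirrors the macro: the runtime flag selects exactly one instantiation to launch.
float dispatch_bounded_max(float row_min, float row_max, bool symmetric_quant) {
  return symmetric_quant ? bounded_max_for_row<true>(row_min, row_max)
                         : bounded_max_for_row<false>(row_min, row_max);
}
```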

template <
PositionEmbeddingMode EmbMode,
CacheLogicalDtype kCacheDtype,
int KVQuantNumGroups = 1>
int KVQuantNumGroups = 1,
bool symmetric = false>
__global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
XQ, // [B_T][N_H][D_H]
@@ -699,10 +744,12 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
XV, // [B_T][N_KVH][D_H]
pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
cache_K, // [B][MAX_T][N_KVH][D_H] or
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged
// attention
pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
cache_V, // [B][MAX_T][N_KVH][D_H] or
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged
// attention
int32_t* qparam_k_ptr,
int32_t* qparam_v_ptr,
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
@@ -873,7 +920,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
}
quantize_fp8_kv(dst, dst_row_q, qparam_row);
quantize_fp8_kv<uint8_t, KVQuantRecipe::perTokenScaling, symmetric>(
dst, dst_row_q, qparam_row);
} else if (kCacheDtype == CacheLogicalDtype::INT4) {
CUDA_KERNEL_ASSERT(D_H_q - D_H / 2 == 4 * KVQuantNumGroups);
quantize_int4_kv<KVQuantNumGroups>(dst, dst_row_q);
@@ -1211,6 +1259,7 @@ at::Tensor nope_qkv_varseq_prefill(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::NOPE,
false, // symmetric_quant is False for FP8 NoPE
varseq_batch_,
varseq_seqpos_,
0,
@@ -1240,6 +1289,7 @@ at::Tensor nope_qkv_varseq_prefill(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::NOPE,
false, // symmetric_quant is False for INT4 NoPE
varseq_batch_,
varseq_seqpos_,
0,
@@ -1406,6 +1456,7 @@ at::Tensor nope_qkv_decoding(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::NOPE,
false, // symmetric_quant is False for FP8 NoPE
batch_,
seqpos_,
0,
@@ -1435,6 +1486,7 @@ at::Tensor nope_qkv_decoding(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::NOPE,
false, // symmetric_quant is False for INT4 NoPE
batch_,
seqpos_,
0,
@@ -1485,7 +1537,8 @@ at::Tensor rope_qkv_varseq_prefill(
bool k_norm = false,
bool update_kv = true,
std::optional<at::Tensor> amax_qkv = std::nullopt,
std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt) {
std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt,
bool symmetric_quant = false) {
auto B_T = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = 0;
@@ -1630,6 +1683,7 @@ at::Tensor rope_qkv_varseq_prefill(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::ROPE,
symmetric_quant, // for FP8 RoPE
varseq_batch_,
varseq_seqpos_,
theta,
@@ -1659,6 +1713,7 @@ at::Tensor rope_qkv_varseq_prefill(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::ROPE,
false, // symmetric_quant is False for INT4 RoPE
varseq_batch_,
varseq_seqpos_,
theta,
@@ -1791,6 +1846,7 @@ at::Tensor xpos_qkv_varseq_prefill(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::XPOS,
false, // symmetric_quant is False for FP8 XPOS
varseq_batch_,
varseq_seqpos_,
theta,
@@ -1820,6 +1876,7 @@ at::Tensor xpos_qkv_varseq_prefill(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::XPOS,
false, // symmetric_quant is False for INT4 XPOS
varseq_batch_,
varseq_seqpos_,
theta,
@@ -1869,7 +1926,8 @@ at::Tensor rope_qkv_decoding(
std::optional<at::Tensor> qparam_v = std::nullopt,
bool k_norm = false,
bool update_kv = true,
std::optional<at::Tensor> amax_qkv = std::nullopt) {
std::optional<at::Tensor> amax_qkv = std::nullopt,
bool symmetric_quant = false) {
auto B = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = 0;
@@ -2002,6 +2060,7 @@ at::Tensor rope_qkv_decoding(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::ROPE,
symmetric_quant, // for FP8 RoPE
nullptr,
seqpos_,
theta,
@@ -2032,6 +2091,7 @@ at::Tensor rope_qkv_decoding(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::ROPE,
false, // symmetric_quant is False for INT4 RoPE
nullptr,
seqpos_,
theta,
@@ -2159,6 +2219,7 @@ at::Tensor xpos_qkv_decoding(
1,
CacheLogicalDtype::FP8,
PositionEmbeddingMode::XPOS,
false, // symmetric_quant is False for FP8 XPOS
nullptr,
seqpos_,
theta,
@@ -2187,6 +2248,7 @@ at::Tensor xpos_qkv_decoding(
num_groups_,
CacheLogicalDtype::INT4,
PositionEmbeddingMode::XPOS,
false, // symmetric_quant is False for INT4 XPOS
nullptr,
seqpos_,
theta,
@@ -2233,7 +2295,7 @@ DEVICE_INLINE uint32_t packComponents(uint32_t x_bits[4]) {
return packed;
}

template <typename T, KVQuantRecipe recipe>
template <typename T, KVQuantRecipe recipe, bool symmetric>
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) {
if (do_norm) {
@@ -2253,7 +2315,12 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) {
warp_min = -warpReduceMax(-thread_min, mask);
warp_max = warpReduceMax(thread_max, mask);

auto bounded_max = (warp_max - warp_min) / 2;
float bounded_max;
if (symmetric) {
bounded_max = warp_max;
} else {
bounded_max = (warp_max - warp_min) / 2;
}
// max FP16 value is 65504.0f.
// Divide by 2 to avoid overflow during
// e4m3fn (NV) to e4m3fnuz (AMD) conversion
@@ -2262,7 +2329,7 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) {
bounded_max = std::min(bounded_max, scale_ub);
scale = static_cast<float>(
std::max(bounded_max / FP8_E4M3_MAX::value, min_scaling_factor));
shift = warp_min + FP8_E4M3_MAX::value * scale;
shift = symmetric ? 0 : warp_min + FP8_E4M3_MAX::value * scale;
} else {
// Support of per-head scaling is limited to reading a
// pre-calculated scale from qparam tensor and using it for scaling the
@@ -2291,11 +2358,16 @@ quantize_fp8_kv(fx4 dst, T* dst_row_q, __half2* qparam, bool do_norm) {
param_store = reinterpret_cast<__half2*>(&dst_row_q[0]);
}
CUDA_KERNEL_ASSERT(uintptr_t(param_store) % 4 == 0);
*param_store = __floats2half2_rn(scale, shift);
if (symmetric) {
float* param_store_fp32 = reinterpret_cast<float*>(param_store);
*param_store_fp32 = scale;
} else {
*param_store = __floats2half2_rn(scale, shift);
}
}
}
#else
template <typename T, KVQuantRecipe recipe>
template <typename T, KVQuantRecipe recipe, bool symmetric = false>
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, T* dst_row_, __half2* qparam, bool do_norm) {}
std::vector<at::Tensor> quantize_fp8_per_tensor(
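For orientation, a hedged host-side sketch of the per-token scale/shift math in quantize_fp8_kv as it reads after this change; the constants below are placeholders for FP8_E4M3_MAX::value, min_scaling_factor, and scale_ub, which are defined elsewhere in the kernel sources:

```cpp
#include <algorithm>

// Placeholder constants; the real values come from the kernel sources.
constexpr float kFp8Max = 448.0f;
constexpr float kMinScalingFactor = 1.0f / (448.0f * 512.0f);
constexpr float kScaleUb = 65504.0f / 2.0f;

struct QParams {
  float scale;
  float shift;  // 0 when quantizing symmetrically
};

// Per-token (per-row) FP8 quantization parameters. Asymmetric mode maps the
// observed [row_min, row_max] range onto the FP8 range via a scale and a shift;
// symmetric mode bounds the scale by row_max alone and drops the shift, which is
// why a single fp32 scale (rather than an fp16 scale/shift pair) is stored per row.
inline QParams fp8_row_qparams(float row_min, float row_max, bool symmetric) {
  float bounded_max = symmetric ? row_max : (row_max - row_min) / 2.0f;
  bounded_max = std::min(bounded_max, kScaleUb);
  const float scale = std::max(bounded_max / kFp8Max, kMinScalingFactor);
  const float shift = symmetric ? 0.0f : row_min + kFp8Max * scale;
  return {scale, shift};
}
```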
9 changes: 6 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h
@@ -77,7 +77,8 @@ at::Tensor rope_qkv_varseq_prefill(
bool k_norm,
bool update_kv,
std::optional<at::Tensor> amax_qkv,
std::optional<at::Tensor> kv_quant_scale_precomputed);
std::optional<at::Tensor> kv_quant_scale_precomputed,
bool symmetric_quant);

at::Tensor rope_qkv_decoding(
at::Tensor XQ,
@@ -103,7 +104,8 @@ at::Tensor rope_qkv_decoding(
std::optional<at::Tensor> qparam_v,
bool k_norm,
bool update_kv,
std::optional<at::Tensor> amax_qkv);
std::optional<at::Tensor> amax_qkv,
bool symmetric_quant);

at::Tensor xpos_qkv_varseq_prefill(
at::Tensor XQ,
@@ -172,7 +174,8 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v,
std::optional<at::Tensor> block_tables,
int64_t page_size);
int64_t page_size,
std::optional<bool> symmetric);

at::Tensor quantize_qkv_per_head(
at::Tensor amax,
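On the read side, dequantize_fp8_cache gains an optional `symmetric` flag. Its implementation is outside this diff, but here is a hedged sketch of how the two qparam layouts written by quantize_fp8_kv would be undone, assuming the quantization convention implied above (an inference for illustration, not the library's code):

```cpp
#include <cassert>

struct RowQParams {
  float scale;
  float shift;  // asymmetric rows store an fp16 (scale, shift) pair;
                // symmetric rows store a single fp32 scale and imply shift == 0
};

// q is the FP8 code of one element, already widened to float.
inline float dequantize_fp8_value(float q, const RowQParams& qp, bool symmetric) {
  if (symmetric) {
    assert(qp.shift == 0.0f);      // no offset was applied at quantization time
    return q * qp.scale;
  }
  return q * qp.scale + qp.shift;  // undo the shift applied in asymmetric mode
}
```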