Skip to content

Commit 2d1dade

Browse files
authored
[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)
* support static cachekv c8 quantization in mtp mode * optimize memory allocation
1 parent 3c36283 commit 2d1dade

File tree

6 files changed

+350
-295
lines changed

6 files changed

+350
-295
lines changed

custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,8 @@ template <typename T,
602602
int VecSize = 4,
603603
int RoundType = 0,
604604
int HeadDim = 128,
605-
bool IsFP8 = false>
605+
bool IsFP8 = false,
606+
bool IsDynamic = true>
606607
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
607608
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
608609
// gqa_group_size, head_size]
@@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
662663
(head_idx - num_heads) % gqa_group_size * block_size +
663664
block_offset;
664665
}
665-
T* cache_k_scale_now = cache_k_scale + cache_offset;
666-
T* cache_v_scale_now = cache_v_scale + cache_offset;
667666

668667
float thread_m2 = 0.0f;
669668
float warp_m2 = 0.0f;
@@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
811810
}
812811
}
813812
// reduce max, 1 head per warp
814-
T local_max = -INFINITY;
813+
if constexpr (IsDynamic) {
814+
T local_max = -INFINITY;
815815
#pragma unroll
816-
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
817-
local_max = __hmax(local_max, __habs(bias_vec1[i]));
818-
local_max = __hmax(local_max, __habs(bias_vec2[i]));
819-
}
816+
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
817+
local_max = __hmax(local_max, __habs(bias_vec1[i]));
818+
local_max = __hmax(local_max, __habs(bias_vec2[i]));
819+
}
820820
#pragma unroll
821-
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
822-
local_max =
823-
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
824-
}
825-
826-
scale = __hdiv(448, local_max);
821+
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
822+
local_max =
823+
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
824+
}
827825

828-
if (lane_id == 0) {
826+
scale = __hdiv(448, local_max);
827+
T* cache_k_scale_now = cache_k_scale + cache_offset;
828+
T* cache_v_scale_now = cache_v_scale + cache_offset;
829+
if (lane_id == 0) {
830+
if (head_idx < num_heads + gqa_group_size) {
831+
cache_k_scale_now[0] = __hdiv(1, scale);
832+
} else {
833+
cache_v_scale_now[0] = __hdiv(1, scale);
834+
}
835+
}
836+
} else {
829837
if (head_idx < num_heads + gqa_group_size) {
830-
cache_k_scale_now[0] = __hdiv(1, scale);
838+
scale = __ldg(&cache_k_scale[kv_head_idx]);
831839
} else {
832-
cache_v_scale_now[0] = __hdiv(1, scale);
840+
scale = __ldg(&cache_v_scale[kv_head_idx]);
833841
}
834842
}
835843

0 commit comments

Comments
 (0)