@@ -602,7 +602,8 @@ template <typename T,
602602 int VecSize = 4 ,
603603 int RoundType = 0 ,
604604 int HeadDim = 128 ,
605- bool IsFP8 = false >
605+ bool IsFP8 = false ,
606+ bool IsDynamic = true >
606607__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel (
607608 const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
608609 // gqa_group_size, head_size]
@@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
662663 (head_idx - num_heads) % gqa_group_size * block_size +
663664 block_offset;
664665 }
665- T* cache_k_scale_now = cache_k_scale + cache_offset;
666- T* cache_v_scale_now = cache_v_scale + cache_offset;
667666
668667 float thread_m2 = 0 .0f ;
669668 float warp_m2 = 0 .0f ;
@@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
811810 }
812811 }
813812 // reduce max, 1 head per warp
814- T local_max = -INFINITY;
813+ if constexpr (IsDynamic) {
814+ T local_max = -INFINITY;
815815#pragma unroll
816- for (int i = 0 ; i < HALF_K_VEC_SIZE; i++) {
817- local_max = __hmax (local_max, __habs (bias_vec1[i]));
818- local_max = __hmax (local_max, __habs (bias_vec2[i]));
819- }
816+ for (int i = 0 ; i < HALF_K_VEC_SIZE; i++) {
817+ local_max = __hmax (local_max, __habs (bias_vec1[i]));
818+ local_max = __hmax (local_max, __habs (bias_vec2[i]));
819+ }
820820#pragma unroll
821- for (int m_offset = 16 ; m_offset > 0 ; m_offset /= 2 ) {
822- local_max =
823- __hmax (local_max, __shfl_xor_sync (0xffffffff , local_max, m_offset));
824- }
825-
826- scale = __hdiv (448 , local_max);
821+ for (int m_offset = 16 ; m_offset > 0 ; m_offset /= 2 ) {
822+ local_max =
823+ __hmax (local_max, __shfl_xor_sync (0xffffffff , local_max, m_offset));
824+ }
827825
828- if (lane_id == 0 ) {
826+ scale = __hdiv (448 , local_max);
827+ T* cache_k_scale_now = cache_k_scale + cache_offset;
828+ T* cache_v_scale_now = cache_v_scale + cache_offset;
829+ if (lane_id == 0 ) {
830+ if (head_idx < num_heads + gqa_group_size) {
831+ cache_k_scale_now[0 ] = __hdiv (1 , scale);
832+ } else {
833+ cache_v_scale_now[0 ] = __hdiv (1 , scale);
834+ }
835+ }
836+ } else {
829837 if (head_idx < num_heads + gqa_group_size) {
830- cache_k_scale_now[ 0 ] = __hdiv ( 1 , scale );
838+ scale = __ldg (&cache_k_scale[kv_head_idx] );
831839 } else {
832- cache_v_scale_now[ 0 ] = __hdiv ( 1 , scale );
840+ scale = __ldg (&cache_v_scale[kv_head_idx] );
833841 }
834842 }
835843
0 commit comments