 #include "../../utilities/cuda/cublaslt_utils.cuh"
 #include "cuda_kernel.hh"
 #include "hardware/functions.h"
-#include "kernel/cuda/reduce.cuh"
+#include "kernel/cuda/functions.cuh"
+#include <cub/block/block_reduce.cuh>

 namespace refactor::kernel {
     using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {

     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    // blockDim.x = min(1024, attLen)
+    // blockDim.x = 1024
     // sizeof(shared) = attLen * sizeof(float)
     template<class T, class Mask>
     static __global__ void softmax(
@@ -36,25 +37,34 @@ namespace refactor::kernel {
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block handles
-        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
         // load the input into shared memory, cast + mask
         extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }

+        using BlockReduce = cub::BlockReduce<float, 1024>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
+        __shared__ float sharedMax, sharedSum;
+
         float localMax = -1e20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             localMax = cub::Max()(localMax, shared[i]);
         }
-        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+        localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+        if (threadIdx.x == 0) { sharedMax = localMax; }
+        __syncthreads();

         float localSum = 1e-20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-            localSum += shared[i] = expf(shared[i] - localMax);
+            localSum += shared[i] = expf(shared[i] - sharedMax);
         }
-        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-        auto reciprocal = fdividef(1, localSum);
+        localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+        if (threadIdx.x == 0) { sharedSum = localSum; }
+        __syncthreads();
+
+        auto reciprocal = fdividef(1, sharedSum);
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             att[i] = shared[i] * reciprocal;
         }
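
A minimal host-side sketch of the launch configuration the kernel comments describe, not code from this commit: the names batch, nHead, seqLen, stream, and the CausalMask functor are hypothetical stand-ins for the real wiring elsewhere in the file.

    // one block per (batch * head, query position), as the kernel comments state
    dim3 grid(batch * nHead, seqLen);
    // dynamic shared memory backs the `shared[]` row buffer: attLen * sizeof(float)
    auto smemBytes = attLen * sizeof(float);
    // blockDim.x is fixed at 1024 to match cub::BlockReduce<float, 1024>
    softmax<<<grid, 1024, smemBytes, stream>>>(att, CausalMask{}, attLen, bufLen);

Fixing blockDim.x at 1024, instead of the previous min(1024, attLen), is what allows BlockReduce to be instantiated with a compile-time block size; threads whose index exceeds attLen skip the strided loops and contribute only the identity seeds (-1e20 / 1e-20) to the reductions.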
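Numerically, each block computes a stable softmax over one attention row: att[i] = expf(x[i] - m) / sum, where m is the row maximum found by the first reduction, so every exponent is at most zero and masked positions (set to -__FLT_MAX__) underflow to zero; the 1e-20 seeds keep the reciprocal finite even for an all-masked row. Reusing the single BlockReduce::TempStorage for both reductions is valid only because a __syncthreads() sits between them; CUB requires a barrier before temp storage is reused, and the barrier that publishes sharedMax serves that purpose.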