@@ -12,25 +12,27 @@ namespace refactor::kernel {
// seqLen: number of tokens processed in this call
// posId: position in the kv cache
// attLen = pastSeqLen + seqLen
// Causal attention mask predicate.
//
//   tokenId: row index in the diagram below
//   seqLen : number of tokens processed in this call
//   posId  : position in the kv cache (column in the diagram below)
//   attLen : pastSeqLen + seqLen
//
// Returns true for the cells marked `*` in the diagram, false otherwise.
//
// The predicate is plain integer arithmetic, so it is declared
// __host__ __device__: device callers are unaffected, and the same code can
// be evaluated on the host (e.g. in a CPU reference or a unit test).
static __forceinline__ __host__ __device__ bool
causualMask(int tokenId, int seqLen,
            int posId, int attLen) {
    //   tokenId ↓ |<---attLen---->|
    //           0 | * * ... *     |
    //           1 | * * ... * *   |
    //           2 | * * ... * * * |
    // seqLen: 3   |---------------|
    //
    // Row `tokenId` covers columns 0 ..= (attLen - seqLen) + tokenId.
    // The comparison is that inequality with seqLen moved to the right-hand
    // side, so no subtraction is needed.
    return attLen + tokenId >= posId + seqLen;
}
// Causal attention mask packaged as a stateless function object, so that it
// can be passed by value to code templated on the mask type.
//
// operator() parameters:
//   tokenId: row index in the diagram below
//   seqLen : number of tokens processed in this call
//   posId  : position in the kv cache (column in the diagram below)
//   attLen : pastSeqLen + seqLen
//
// operator() returns true for the cells marked `*` in the diagram, false
// otherwise.
struct AttentionCausualMask {
    // `const`: the functor holds no state, so it must be callable through
    // const objects and const references as well.
    // `__host__ __device__`: the predicate is plain integer arithmetic, so it
    // can be evaluated on either side (e.g. in a CPU reference or unit test).
    __forceinline__ __host__ __device__ bool
    operator()(int tokenId, int seqLen,
               int posId, int attLen) const {
        //   tokenId ↓ |<---attLen---->|
        //           0 | * * ... *     |
        //           1 | * * ... * *   |
        //           2 | * * ... * * * |
        // seqLen: 3   |---------------|
        //
        // Row `tokenId` covers columns 0 ..= (attLen - seqLen) + tokenId.
        // The comparison is that inequality with seqLen moved to the
        // right-hand side, so no subtraction is needed.
        return attLen + tokenId >= posId + seqLen;
    }
};
2527
2628 // gridDim.x = batch * nHead
2729 // gridDim.y = seqLen
2830 // blockDim.x = min(1024, attLen)
2931 // sizeof(shared) = attLen * sizeof(float)
30- template <class T >
32+ template <class T , class Mask >
3133 static __global__ void softmax (
3234 T *__restrict__ att,
33- bool (* mask)( int , int , int , int ) ,
35+ Mask mask,
3436 uint32_t attLen,
3537 uint32_t bufLen) {
3638 // 找到这个线程块对应的 attention 区域
@@ -161,7 +163,7 @@ namespace refactor::kernel {
161163 std::min (1024u , attLen),
162164 attLen * sizeof(float ),
163165 stream>>>(
164- att, causualMask , attLen, bufLen);
166+ att, AttentionCausualMask() , attLen, bufLen);
165167 {
166168 half alpha = 1 , beta = 0 ;
167169 cublasLtMatmul (
0 commit comments