Commit 1b200d0

committed
fix(kernel): fix the attention memory access error
Signed-off-by: YdrMaster <[email protected]>
1 parent 62187a6 commit 1b200d0

File tree

2 files changed: +17 −13 lines


src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 15 additions & 13 deletions
@@ -12,25 +12,27 @@ namespace refactor::kernel {
     // seqLen: number of tokens processed in this pass
     // posId: position in the kv cache
     // attLen = pastSeqLen + seqLen
-    static __forceinline__ __device__ bool
-    causualMask(int tokenId, int seqLen,
-                int posId, int attLen) {
-        // tokenId ↓ |<---attLen---->|
-        //         0 | * * ... *     |
-        //         1 | * * ... * *   |
-        //         2 | * * ... * * * |
-        // seqLen: 3 |---------------|
-        return attLen + tokenId >= posId + seqLen;
-    }
+    struct AttentionCausualMask {
+        __forceinline__ __device__ bool
+        operator()(int tokenId, int seqLen,
+                   int posId, int attLen) {
+            // tokenId ↓ |<---attLen---->|
+            //         0 | * * ... *     |
+            //         1 | * * ... * *   |
+            //         2 | * * ... * * * |
+            // seqLen: 3 |---------------|
+            return attLen + tokenId >= posId + seqLen;
+        }
+    };
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
     // blockDim.x = min(1024, attLen)
     // sizeof(shared) = attLen * sizeof(float)
-    template<class T>
+    template<class T, class Mask>
     static __global__ void softmax(
         T *__restrict__ att,
-        bool (*mask)(int, int, int, int),
+        Mask mask,
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block handles
@@ -161,7 +163,7 @@ namespace refactor::kernel {
                 std::min(1024u, attLen),
                 attLen * sizeof(float),
                 stream>>>(
-                att, causualMask, attLen, bufLen);
+                att, AttentionCausualMask(), attLen, bufLen);
             {
                 half alpha = 1, beta = 0;
                 cublasLtMatmul(
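
Why the functor fixes the access fault: the removed `causualMask` was a `__device__` function, yet the launch site passed its address from host code through a plain function pointer parameter (`bool (*mask)(int, int, int, int)`). A `__device__` function's address taken on the host is not callable from device code, so invoking `mask(...)` inside `softmax` dereferenced an invalid address. Passing a stateless functor by value avoids function pointers altogether: `operator()` is resolved statically when `softmax<T, Mask>` is instantiated. As a sanity check on the unchanged predicate: with seqLen = 3 and attLen = 5 (pastSeqLen = 2), the mask admits posId 0..2 for token 0, 0..3 for token 1, and 0..4 for token 2, matching the staircase drawn in the comment. Below is a minimal standalone sketch of the same functor-to-kernel pattern; the names (`IsEven`, `countIf`) are hypothetical and not from this repository:

    // Stateless functor: copied to the device by value at launch, so the
    // call to operator() is compiled into the kernel -- no host-side
    // function address ever reaches device code.
    struct IsEven {
        __forceinline__ __device__ bool operator()(int x) const {
            return x % 2 == 0;
        }
    };

    template<class Pred>
    static __global__ void countIf(int const *data, int *hits, int n, Pred pred) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n && pred(data[i])) {
            atomicAdd(hits, 1);
        }
    }

    // launch site, mirroring the kernel call in cuda_kernel.cu:
    //     countIf<<<blocks, threads>>>(d_data, d_hits, n, IsEven{});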

src/04kernel/test/kernels/attention/test_cuda.cpp

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include "../../../src/kernels/attention/cuda_kernel.hh"
 #include "hardware/device_manager.h"
+#include "kernel/cuda/functions.cuh"
 #include <gtest/gtest.h>
 #include <numeric>
 
@@ -43,6 +44,7 @@ TEST(kernel, AttentionCudaNoKvCache) {
         void *outputs[]{*oGpu};
         routine(res, *workspace, inputs, outputs);
     }
+    cuda::sync();
 }
 
 #endif
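
Why the test needs `cuda::sync()`: kernel launches are asynchronous, so an illegal memory access inside `softmax` only surfaces at the next synchronizing CUDA call; without the added sync the test could finish before the fault is reported. The helper comes from the newly included "kernel/cuda/functions.cuh", whose body is not shown in this diff. A minimal sketch of what such a helper typically wraps (an assumption, not the repository's actual implementation):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    namespace cuda {
        // Block until all queued device work completes, then surface any
        // asynchronous error (e.g. an illegal memory access raised by a
        // kernel long after its launch call returned).
        inline void sync() {
            if (auto err = cudaDeviceSynchronize(); err != cudaSuccess) {
                std::fprintf(stderr, "cuda error: %s\n", cudaGetErrorString(err));
                std::abort();
            }
        }
    }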
