new attention kernel using custom matmul
ngc92 committed Apr 22, 2024
1 parent 05b09a6 commit d1c22b5
Showing 1 changed file with 146 additions and 0 deletions.
dev/cuda/attention_forward.cu
@@ -649,6 +649,90 @@ __global__ void attention_forward_fused1(float* out, float* preatt, float* att,
}
}

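// vectorized 128-bit load/store helpers; the pointers they are given must be 16-byte aligned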
__device__ float4 ld_vec(const float* address) {
return *reinterpret_cast<const float4*>(address);
}

__device__ void st_vec(float* address, float4 val) {
*reinterpret_cast<float4*>(address) = val;
}

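// Triangular matrix multiply: computes one 128x128 tile of p = alpha * (first arg) @ (second arg)^T,
// i.e. p[i*ps + j] = alpha * dot over the head dimension of row i of the first matrix and row j of
// the second. Each 16x16 thread block owns one tile and each thread accumulates an 8x8 sub-tile in
// registers; 128x32 slices of both inputs are staged in shared memory (2 x 16 KiB per block).
// Blocks strictly above the block diagonal exit early, so only the lower block-triangle of the
// T x T output is produced. Note: trimul_global below passes the query matrix as the 'k' argument
// (it indexes the output rows) and the key matrix as 'q', so p ends up holding scale * Q @ K^T.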
__device__ void matmul_tri(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

if (blockIdx.y > blockIdx.x)
return;

k += 128 * blockIdx.x * ks;
q += 128 * blockIdx.y * qs;

__shared__ float lhs_s[128][32];
__shared__ float rhs_s[128][32];

float vals[8][8] = {};
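    // march over the head dimension in 32-wide chunks; each chunk cooperatively stages a 128x32 slice of k and q in shared memory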
for (int so = 0; so < hs; so += 32) {
__syncthreads();
for(int y = threadIdx.y / 2; y < 128; y += 8) {
int xo = (threadIdx.y % 2) * 16;
lhs_s[y][threadIdx.x + xo] = k[y * ks + so + threadIdx.x + xo];
rhs_s[y][threadIdx.x + xo] = q[y * qs + so + threadIdx.x + xo];
}
__syncthreads();

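        // each thread accumulates an 8x8 outer product per column si; the (si + threadIdx.x) % 32
        // offset staggers shared-memory reads across banks (the sum over all 32 columns is unchanged)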
for (int si = 0; si < 32; ++si) {
float rhs[8];
for (int u = 0; u < 8; ++u) {
rhs[u] = rhs_s[u + 8 * threadIdx.y][(si + threadIdx.x) % 32];
}

for (int ii = 0; ii < 8; ++ii) {
float lhs = lhs_s[ii + 8 * threadIdx.x][(si + threadIdx.x) % 32];
for (int ji = 0; ji < 8; ++ji) {
vals[ii][ji] += lhs * rhs[ji];
}
}
}
}

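    // write out the finished 8x8 result tile with vectorized 128-bit stores, applying the softmax scale alpha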
for (int ii = 0; ii < 8; ++ii) {
for (int ji = 0; ji < 8; ji += 4) {
int i = i_base + ii;
int j = j_base + ji;
float4 result;
result.x = vals[ii][ji + 0] * alpha;
result.y = vals[ii][ji + 1] * alpha;
result.z = vals[ii][ji + 2] * alpha;
result.w = vals[ii][ji + 3] * alpha;
st_vec(p + i * ps + j, result);
}
}
}

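// Wrapper kernel: maps blockIdx.z to (batch, head), locates q and k inside the (B, T, 3C) input,
// and delegates the tile computation to the matmul passed as a template parameter.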
template<auto matmul_tri>
__global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {
// skip above the diagonal
if(blockIdx.y > blockIdx.x)
return;

// set up indices
int C3 = C*3;
int hs = C / NH; // head size
float scale = 1.0 / sqrtf(hs);

// we put the "batch x head" dimension into the z block index.
int h = blockIdx.z % NH;
int b = blockIdx.z / NH;

// Get the base address for the current batch and head
const float* q = inp + b * T * C3 + h * hs;
const float* k = inp + b * T * C3 + h * hs + C;
float* r = out + (b*NH + h)*T*T;

// start the multiplication
matmul_tri(r, T, q, C3, k, C3, T, hs, scale);
}
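
// For reference only: a minimal CPU sketch (an illustrative, hypothetical helper, not part of this
// commit) of what trimul_global produces per (batch, head) on the lower triangle, i.e. the part the
// causal softmax consumes: preatt[i][j] = scale * dot(q_i, k_j), reading q and k directly from the
// (B, T, 3C) input layout.
void trimul_cpu_reference(float* preatt, const float* inp, int B, int T, int C, int NH) {
    int C3 = C * 3;
    int hs = C / NH;
    float scale = 1.0f / sqrtf((float)hs);
    for (int b = 0; b < B; b++) {
        for (int h = 0; h < NH; h++) {
            const float* q = inp + b * T * C3 + h * hs;
            const float* k = inp + b * T * C3 + h * hs + C;
            float* p = preatt + (b * NH + h) * T * T;
            for (int i = 0; i < T; i++) {
                for (int j = 0; j <= i; j++) {      // lower triangle only
                    float acc = 0.0f;
                    for (int s = 0; s < hs; s++) {
                        acc += q[i * C3 + s] * k[j * C3 + s];
                    }
                    p[i * T + j] = acc * scale;
                }
            }
        }
    }
}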

// ----------------------------------------------------------------------------
// kernel launcher

@@ -763,6 +847,7 @@ void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, f
int total_threads = B * NH * T * HS;
int num_blocks = ceil_div(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
cudaCheck(cudaGetLastError());

// batched matrix multiply with cuBLAS
const float alpha = 1.0f;
@@ -788,6 +873,7 @@ void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, f
int grid_size = B * NH * T;
size_t shared_mem_size = 2 * softmax_block_size / 32 * sizeof(float);
softmax_forward_kernel4<<<grid_size, softmax_block_size, shared_mem_size>>>(att, preatt, B * NH * T, T);
cudaCheck(cudaGetLastError());

// now another cuBLAS batched matmul
// y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
@@ -805,6 +891,7 @@ void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, f
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = ceil_div(B * T * C, block_size);
unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}

void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
@@ -824,6 +911,7 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f
int total_threads = B * NH * T * HS;
int num_blocks = ceil_div(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
cudaCheck(cudaGetLastError());

// batched matrix multiply with cuBLAS
const float alpha = 1.0f;
@@ -843,6 +931,7 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f
int softmax_block_size = 256;
int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);
softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
cudaCheck(cudaGetLastError());

// now another cuBLAS batched matmul
// y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
@@ -860,6 +949,7 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = ceil_div(B * T * C, block_size);
unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}

void attention_forward5(float* out, float* preatt, float* att,
@@ -869,6 +959,59 @@ void attention_forward5(float* out, float* preatt, float* att,
// attention calculation
int x_blocks = ceil_div(T, block_size / 32);
attention_forward_fused1<<<dim3(x_blocks, NH, B), block_size>>>(out, preatt, att, inp, B, T, C, NH);
cudaCheck(cudaGetLastError());
}

void attention_forward6(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
const float* inp,
int B, int T, int C, int NH,
const int block_size) {
// inp is (B, T, 3C) QKV
// preatt, att are (B, NH, T, T)
// output is (B, T, C)
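    // same structure as attention_forward4, except the first cuBLAS batched matmul
    // (preatt = scale * Q @ K^T) is replaced by the custom triangular-tile kernel above,
    // which skips the blocks above the causal mask entirely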
int HS = C / NH; // head size

// permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
// TODO we don't need q and k anymore, but v is still needed later.
float *q, *k, *v;
q = qkvr + 0 * B * T * C;
k = qkvr + 1 * B * T * C;
v = qkvr + 2 * B * T * C;
int total_threads = B * NH * T * HS;
int num_blocks = ceil_div(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
cudaCheck(cudaGetLastError());


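    // one thread block per 128x128 tile of preatt; this launch assumes T is a multiple of 128,
    // and the 16x16 = 256-thread block matches the kernel's __launch_bounds__(256, 2)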
trimul_global<matmul_tri><<<dim3(T / 128, T / 128, NH * B), dim3(16, 16)>>>(preatt, inp, T, C, NH);
cudaCheck(cudaGetLastError());

    // softmax over preatt; the 1/sqrt(hs) scale was already applied inside matmul_tri, so pass 1.0 here
    float scale = 1.0f;
int softmax_block_size = 256;
int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);
softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
cudaCheck(cudaGetLastError());

// now another cuBLAS batched matmul
// y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
const float alpha = 1.0f;
const float beta = 0.0f;
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
HS, T, T,
&alpha,
v, HS, T * HS,
att, T, T * T,
&beta,
vaccum, HS, T * HS,
B * NH));

// now unpermute
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = ceil_div(B * T * C, block_size);
unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
@@ -893,6 +1036,9 @@ void attention_forward(int kernel_num,
case 5:
attention_forward5(out, preatt, att, inp, B, T, C, NH, block_size);
break;
case 6:
attention_forward6(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);