
Commit 067aadc

Merge pull request #553 from gordicaleksa/refactor_trimat

Refactor trimat

2 parents b117105 + c116fbf

1 file changed: +123 -65 lines changed

dev/cuda/trimat_forward.cu (+123 -65)
@@ -65,30 +65,38 @@ static float* d_qkvr; // scratch for the cublas kernel
 // taken from then attention forward pass
 void trimul_cpu(float* out, const float* inp,
                 int B, int T, int C, int NH) {
+    // inp shape: (B, T, 3, NH, HS)
+    // out shape: (B, NH, T, T)
     int C3 = C*3;
-    int hs = C / NH; // head size
-    float scale = 1.0 / sqrtf(hs);
+    int HS = C / NH; // head size
+    float scale = 1.0 / sqrtf(HS);

     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            for (int h = 0; h < NH; h++) {
-                const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
-                float* out_bth = out + b * NH * T * T + h * T * T + t * T;
+            for (int nh = 0; nh < NH; nh++) {
+                // Q[b][nh][t][:] = inp[b][t][0][nh][:] (where : is the slice operator for hs)
+                const float* query_t = inp + b * T * C3 + t * C3 + nh * HS;
+                // out[b][nh][t][:]
+                float* out_bth = out + b * NH * T * T + nh * T * T + t * T;

                 // pass 1: calculate query dot key and maxval
                 for (int t2 = 0; t2 <= t; t2++) {
-                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
+                    // K[b][nh][t2][:] = inp[b][t2][1][nh][:]
+                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + nh * HS + C; // +C because it's key

-                    // (query_t) dot (key_t2)
+                    // Q[b][nh][t][:] dot K[b][nh][t2][:]
                     float val = 0.0f;
-                    for (int i = 0; i < hs; i++) {
+                    for (int i = 0; i < HS; i++) {
                         val += query_t[i] * key_t2[i];
                     }
                     val *= scale;

+                    // out[b][nh][t][t2] = val
                     out_bth[t2] = val;
                 }
                 for(int t2 = t + 1; t2 < T; ++t2) {
+                    // causal mask, using NAN to supress warnings -> it could be -inf
+                    // but it doesn't matter because in validate_result we ignore infinities/NANs
                     out_bth[t2] = NAN;
                 }
             }
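For reference, the indexing convention above can be sanity-checked with toy sizes; the numbers below are made up for illustration and are not the benchmark's sizes:

    // Toy sizes: B=1, T=4, C=6, NH=2  =>  C3=18, HS=3.
    // Offset of Q[b][nh][t][0] inside the packed (B, T, 3, NH, HS) input:
    int q_offset(int b, int t, int nh, int T, int C, int NH) {
        int C3 = 3 * C, HS = C / NH;
        return b * T * C3 + t * C3 + 0 * C + nh * HS;  // "0 * C" selects the Q third of the row
    }
    // e.g. q_offset(0, 2, 1, /*T=*/4, /*C=*/6, /*NH=*/2) == 39;
    // the matching key vector starts C = 6 floats later, at offset 45 (the "+ C" above).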
@@ -98,31 +106,31 @@ void trimul_cpu(float* out, const float* inp,

 __global__ void permute_kernel(float* q, float* k, float* v,
                                const float* inp,
-                               int B, int N, int NH, int d) {
-    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)
-    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)
+                               int B, int T, int NH, int HS) {
+    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, T, HS)
+    // but instead, we have a single tensor QKV (inp) of shape (B, T, 3, NH, HS)
     int idx = blockIdx.x * blockDim.x + threadIdx.x;

-    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]
+    // Q[b][nh][t][hs] = inp[b][t][0][nh][hs]

-    if (idx < B * NH * N * d) {
-        int b = idx / (NH * N * d);
-        int rest = idx % (NH * N * d);
-        int nh_ = rest / (N * d);
-        rest = rest % (N * d);
-        int n = rest / d;
-        int d_ = rest % d;
+    if (idx < B * NH * T * HS) {
+        int b = idx / (NH * T * HS);
+        int rest = idx % (NH * T * HS);
+        int nh = rest / (T * HS);
+        rest = rest % (T * HS);
+        int t = rest / HS;
+        int hs = rest % HS;

         int inp_idx = \
-            (b * N * 3 * NH * d)
-            + (n * 3 * NH * d)
-            + (0 * NH * d)
-            + (nh_ * d)
-            + d_;
+            (b * T * 3 * NH * HS)
+            + (t * 3 * NH * HS)
+            + (0 * NH * HS)
+            + (nh * HS)
+            + hs;

         q[idx] = inp[inp_idx];
-        k[idx] = inp[inp_idx + NH * d];
-        v[idx] = inp[inp_idx + 2 * (NH * d)];
+        k[idx] = inp[inp_idx + NH * HS];
+        v[idx] = inp[inp_idx + 2 * (NH * HS)];
     }
 }
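A launch of this kernel uses one thread per element of Q (and K, V). A minimal host-side sketch, with an assumed block size of 256 (the exact launch parameters are not part of this hunk, and d_q, d_k, d_v, d_inp are placeholder names for the device buffers):

    int total = B * NH * T * HS;                      // one thread per output element
    int block_size = 256;                             // assumed block size
    int grid_size = (total + block_size - 1) / block_size;
    permute_kernel<<<grid_size, block_size>>>(d_q, d_k, d_v, d_inp, B, T, NH, HS);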

@@ -145,6 +153,35 @@ void trimul_cublas(float* preatt,
     // batched matrix multiply with cuBLAS
     const float alpha = 1.0f / sqrtf(HS);
     const float beta = 0.0f;
+    // This schedules in parallel B*NH matmuls of shape q@k^t = (T, HS) @ (HS, T) = (T, T).
+    // IMPORTANT NOTE: Cublas uses a column-major (and we use row-major in our codebase) representation,
+    // so this call might look confusing to you if you look at the `cublasSgemmStridedBatched` signature.
+    //
+    // In order to avoid having to do an additional transpose operation after this func call,
+    // we need to pass in K as the first argument and Q as the second argument, which might make you think we're computing K^T @ Q.
+    // That combined with the shapes we got after the permute kernel - (B, NH, T, HS) (I'll omit B, NH for brevity going forward)
+    // and you might think we end up with (HS, T) @ (T, HS) = (HS, HS).
+    // This is not the case. :)
+    //
+    // Cublas sees our row-major matrix (T, HS) as (HS, T), hence we set the lead dimensions to HS (see function signature).
+    // We transpose K and end up computing K^T @ Q = (T, HS) @ (HS, T) = (T, T).
+    // If you were to interpret the above formula K^T @ Q you might think we end up with:
+    // -----------------------------------
+    // k1.dot(q1) k1.dot(q2) ... k1.dot(qT)
+    // k2.dot(q1) k2.dot(q2) ... k2.dot(qT)
+    // ...
+    // kT.dot(q1) kT.dot(q2) ... kT.dot(qT)
+    // -----------------------------------
+    // But as I mentioned, Cublas is column-major!
+    // So given that the dot product is symmetric we can write k1.dot(q1) as q1.dot(k1) and transposing the above
+    // representation we can see what we actually end up with in the row-major format:
+    // -----------------------------------
+    // q1.dot(k1) q1.dot(k2) ... q1.dot(kT)
+    // q2.dot(k1) q2.dot(k2) ... q2.dot(kT)
+    // ...
+    // qT.dot(k1) qT.dot(k2) ... qT.dot(kT)
+    // -----------------------------------
+    // which is exactly what we wanted! :)
     cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                           CUBLAS_OP_T, CUBLAS_OP_N,
                                           T, T, HS,
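The row-major/column-major point in the comment can be checked on a tiny standalone example, independent of cuBLAS: reading the same flat buffer with the other convention gives the transpose.

    // A 2x3 row-major matrix stored as a flat buffer:
    //     [ 1 2 3 ]
    //     [ 4 5 6 ]
    float m[6] = {1, 2, 3, 4, 5, 6};
    // row-major, leading dimension 3:    element (r, c) = m[r * 3 + c]  -> (1, 0) is 4
    // column-major, leading dimension 3: element (r, c) = m[c * 3 + r]  -> (1, 0) is 2
    // i.e. the column-major view of this buffer is the 3x2 transpose of the row-major view,
    // which is exactly the reinterpretation the comment above relies on.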
@@ -173,7 +210,7 @@ void trimul_cublas(float* preatt,
 */

 // using creates an alias for a function pointer
-using matmul_fn_ptr = void(*)(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha);
+using matmul_fn_ptr = void(*)(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha);

 template<matmul_fn_ptr matmul_tri>
 __global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {
@@ -183,20 +220,21 @@ __global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float*

     // set up indices
     int C3 = C*3;
-    int hs = C / NH; // head size
-    float scale = 1.0 / sqrtf(hs);
+    int HS = C / NH; // head size
+    float scale = 1.0 / sqrtf(HS);

     // we put the "batch x head" dimension into the z block index.
-    int h = blockIdx.z % NH;
     int b = blockIdx.z / NH;
+    int nh = blockIdx.z % NH;

     // Get the base address for the current batch and head
-    const float* q = inp + b * T * C3 + h * hs;
-    const float* k = inp + b * T * C3 + h * hs + C;
-    float* r = out + (b*NH + h)*T*T;
+    // shapes -> inp (B, T, 3, NH, HS), Q (B, NH, T, HS), K (B, NH, T, HS)
+    const float* q = inp + b * T * C3 + nh * HS; // Q[b][nh][:][:] = inp[b][:][0][nh][:]
+    const float* k = inp + b * T * C3 + nh * HS + C; // K[b][nh][:][:] = inp[b][:][1][nh][:]
+    float* r = out + (b*NH + nh)*T*T; // out[b][nh][:][:]

     // start the multiplication
-    matmul_tri(r, T, q, C3, k, C3, T, hs, scale);
+    matmul_tri(r, T, k, C3, q, C3, T, HS, scale);
 }

 template<matmul_fn_ptr matmul_tri>
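For context, this kernel is launched once per matmul variant via the template parameter. A minimal sketch of such a launch (the exact grid rounding is an assumption; the 128x128-tile-per-block and 16x16-thread-block shape follow from the kernels below):

    dim3 block(16, 16);                                    // 256 threads, each producing an 8x8 output tile
    dim3 grid((T + 127) / 128, (T + 127) / 128, B * NH);   // one 128x128 tile per block, z indexes (b, nh)
    trimul_global<matmul_tri_naive><<<grid, block>>>(out, inp, T, C, NH);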
@@ -239,12 +277,22 @@ void trimul_launcher(float* out, const float* inp, int B, int T, int C, int NH)
 */

 // baseline implementation: 20 ms
-__device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
-    // get coordinates of our block
+__device__ void matmul_tri_naive(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
+    // coordinate system:
+    // | - - - - - > j
+    // |
+    // |
+    // v
+    // i
+    // get coordinates of our block - each thread is responsible for a single 8x8 block.
     int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
     int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

-    // one more check to skip the upper diagonal in blocks that are on the diagonal.
+    // One more check to skip the upper diagonal in blocks that are on the diagonal.
+    // Note: we deliberately waste some compute on the jagged diagonal i.e. elements that belong
+    // to the upper triangle that should be masked out. This will be ignored due to the causal mask
+    // in the reference CPU implementation when used in the `validate_result` function.
+    // Alternatively this check should be done in the nested for loop below -> if (i > j) return.
     if(j_base > i_base)
         return;
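The mapping from threads to output elements implied by i_base/j_base, spelled out as a worked example:

    // blockDim = (16, 16): each thread owns an 8x8 sub-tile of the (T, T) output,
    //   rows    [i_base, i_base + 8)  with  i_base = 128 * blockIdx.x + 8 * threadIdx.x
    //   columns [j_base, j_base + 8)  with  j_base = 128 * blockIdx.y + 8 * threadIdx.y
    // so each block covers a 128x128 tile; e.g. for T = 1024 the grid is 8 x 8 x (B * NH) blocks,
    // and threads whose whole 8x8 tile lies above the diagonal exit immediately.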

@@ -254,17 +302,17 @@ __device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const
         for(int jo = 0; jo < 8; ++jo) {
             int j = j_base + jo;
             float val = 0;
-            for (int s = 0; s < hs; ++s) {
-                val += k[i * ks + s] * q[j * qs + s];
+            for (int s = 0; s < HS; ++s) {
+                val += q[i * QS + s] * k[j * KS + s];
             }
-            p[i * ps + j] = val * alpha;
+            p[i * PS + j] = val * alpha;
         }
     }
 }

 /* ** Chapter IV - ... **
  *
- * Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send there runners of 64 times
+ * Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send their runners 64 times
  * to fetch the corresponding shapes. This is terribly inefficient; The runners need a minute or so for each trip,
  * but making a cookie can be done in just a second.
  *
@@ -292,25 +340,25 @@ __device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const
 */

 // reorganize loops to enable data reuse: 3.5 ms
-__device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
+__device__ void matmul_tri_registers(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
     int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
     int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

     if (j_base > i_base)
         return;

     // shift our pointers to the sub-block this thread is responsible for
-    k += i_base * ks;
-    q += j_base * qs;
-    p += i_base * ps + j_base;
+    q += i_base * QS;
+    k += j_base * KS;
+    p += i_base * PS + j_base;

     float vals[8][8] = {};
-    for (int s = 0; s < hs; ++s) {
+    for (int hs = 0; hs < HS; ++hs) {
         float lhs[8];
         float rhs[8];
         for (int u = 0; u < 8; ++u) {
-            lhs[u] = k[u * ks + s];
-            rhs[u] = q[u * qs + s];
+            lhs[u] = q[u * QS + hs];
+            rhs[u] = k[u * KS + hs];
         }

         for (int i = 0; i < 8; ++i) {
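The unchanged lines elided between this hunk and the next complete the loop opened on the last line above by accumulating the 8x8 outer product of lhs and rhs into vals; roughly (a sketch of the elided context, not a change in this commit):

    // inside the loop over i opened above:
    for (int j = 0; j < 8; ++j) {
        vals[i][j] += lhs[i] * rhs[j];
    }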
@@ -322,7 +370,7 @@ __device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, c

     for (int i = 0; i < 8; ++i) {
         for (int j = 0; j < 8; ++j) {
-            p[i * ps + j] = vals[i][j] * alpha;
+            p[i * PS + j] = vals[i][j] * alpha;
         }
     }
 }
@@ -334,7 +382,7 @@ __device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, c
 * "Of course", the runner answers, "but they've asked me for an elephant, a lion, a zebra, and a goldfish. These
 * are all over the place, I can't just pick them up at one spot (_strided acccess_).
 * "But the lion is right next to the palm tree. You could bring those two together?", you confirm.
- * "Yes", he says, "if the just asked for the different categories at the same time, that would make things
+ * "Yes", he says, "if they just asked for the different categories at the same time, that would make things
 * so much easier. See, I have this bucket, I could carry lots of things in one go if I could just scoop them up
 * from the same place (_coalesced access_).
 *
@@ -364,29 +412,30 @@ __device__ void st_vec(float* address, float4 val) {
 }

 // vector instructions for coalesced memory access: 1.7 ms
-__device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
+__device__ void matmul_tri3(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
+    // Same logic as previous kernel we just load in float4 to improve coalescing
     int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
     int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

     if (j_base > i_base)
         return;

     // shift our pointers to the sub-block this thread is responsible for
-    k += i_base * ks;
-    q += j_base * qs;
-    p += i_base * ps + j_base;
+    q += i_base * QS;
+    k += j_base * KS;
+    p += i_base * PS + j_base;

     float vals[8][8] = {};
-    for (int s = 0; s < hs; s += 4) {
+    for (int hs = 0; hs < HS; hs += 4) {
         // load in float4 to improve coalescing
         float4 rhs[8];
         for (int u = 0; u < 8; ++u) {
-            rhs[u] = ld_vec(q + u * qs + s);
+            rhs[u] = ld_vec(k + u * KS + hs);
         }

         for (int i = 0; i < 8; ++i) {
             // no need to keep lhs around for the i loop, its only reused in the j loop anyway.
-            float4 lhs = ld_vec(k + i * ks + s);
+            float4 lhs = ld_vec(q + i * QS + hs);
             for (int j = 0; j < 8; ++j) {
                 vals[i][j] += lhs.x * rhs[j].x;
                 vals[i][j] += lhs.y * rhs[j].y;
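ld_vec and st_vec are small helpers defined earlier in the file (the @@ context line above points at st_vec); they amount to 128-bit vectorized loads and stores, roughly like this sketch (an assumption about the exact implementation, which is unchanged by this commit):

    __device__ float4 ld_vec(const float* address) {
        return *reinterpret_cast<const float4*>(address);  // one 128-bit load
    }
    __device__ void st_vec(float* address, float4 val) {
        *reinterpret_cast<float4*>(address) = val;          // one 128-bit store
    }

This is also why the hs loop above advances four floats at a time.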
@@ -403,7 +452,7 @@ __device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const floa
             result.y = vals[i][j + 1] * alpha;
             result.z = vals[i][j + 2] * alpha;
             result.w = vals[i][j + 3] * alpha;
-            st_vec(p + i * ps + j, result);
+            st_vec(p + i * PS + j, result);
         }
     }
 }
@@ -424,7 +473,7 @@ __device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const floa
 * details.]
 *
 */
-__device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
+__device__ void matmul_tri4(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
     int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
     int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

@@ -433,29 +482,38 @@ __device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const floa
     if (blockIdx.y > blockIdx.x)
         return;

-    k += 128 * blockIdx.x * ks;
-    q += 128 * blockIdx.y * qs;
+    q += 128 * blockIdx.x * QS;
+    k += 128 * blockIdx.y * KS;

     __shared__ float lhs_s[128][32];
     __shared__ float rhs_s[128][32];

     float vals[8][8] = {};
-    for (int so = 0; so < hs; so += 32) {
+    for (int so = 0; so < HS; so += 32) {
         // Read a large slice of the input, worked on together by all threads.
         // They are organized differently for this part. We want to ensure
         // fully coalesced loads, so we let a single warp handle consecutive
         // addresses, which means we need to combine two threadIdx.y values
         // in one read operation.
         // note: threads may read data here that they don't need themselves.
         // this really is a block-level operation.
+        // note2: 16x16 threads (i.e. the block) will, through this for loop, fetch 32 dims from 128 keys and 128 queries
+        // i.e. from Q/K, of shape (T, HS) take q[:128, so*32:(so+1)*32] and k[:128, so*32:(so+1)*32]
         __syncthreads();
         for(int y = threadIdx.y / 2; y < 128; y += 8) {
             int xo = (threadIdx.y % 2) * 16;
-            lhs_s[y][threadIdx.x + xo] = k[y * ks + so + threadIdx.x + xo];
-            rhs_s[y][threadIdx.x + xo] = q[y * qs + so + threadIdx.x + xo];
+            lhs_s[y][threadIdx.x + xo] = q[y * QS + so + threadIdx.x + xo];
+            rhs_s[y][threadIdx.x + xo] = k[y * KS + so + threadIdx.x + xo];
         }
         __syncthreads();

+        // Now we compute a partial dot product (only 32 dims) for all combinations of keys and queries (128x128).
+        // Each thread does 8x8 of these partial dot products.
+        // E.g. thread (0,0) covers queries 0-7 and keys 0-7. More generally first row of threads
+        // (0,:) covers queries 0-7 with keys 0-127 and so on.
+        // In the next iterations of the outer (`so`) loop we'll be accumulating values to `vals` until we
+        // get the full dot product. We then later deposit it into the output matrix for all 8x8 blocks
+        // that are below the diagonal.
         for (int si = 0; si < 32; ++si) {
             float rhs[8];
             for (int u = 0; u < 8; ++u) {
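The shared-memory budget implied by the two tiles above is worth spelling out (simple arithmetic, not part of the diff):

    // lhs_s and rhs_s are each 128 x 32 floats = 16 KiB, so 32 KiB of shared memory per block.
    // Together with __launch_bounds__(256, 2) (256 threads, at least 2 resident blocks per SM),
    // a single SM running this kernel holds 512 threads and 64 KiB of shared memory for these tiles.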
@@ -484,7 +542,7 @@ __device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const floa
             result.y = vals[ii][ji + 1] * alpha;
             result.z = vals[ii][ji + 2] * alpha;
             result.w = vals[ii][ji + 3] * alpha;
-            st_vec(p + i * ps + j, result);
+            st_vec(p + i * PS + j, result);
         }
     }
 }
