@@ -65,30 +65,38 @@ static float* d_qkvr; // scratch for the cublas kernel
// taken from the attention forward pass
void trimul_cpu(float* out, const float* inp,
                int B, int T, int C, int NH) {
+    // inp shape: (B, T, 3, NH, HS)
+    // out shape: (B, NH, T, T)
    int C3 = C*3;
-    int hs = C / NH; // head size
-    float scale = 1.0 / sqrtf(hs);
+    int HS = C / NH; // head size
+    float scale = 1.0 / sqrtf(HS);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
-            for (int h = 0; h < NH; h++) {
-                const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
-                float* out_bth = out + b * NH * T * T + h * T * T + t * T;
+            for (int nh = 0; nh < NH; nh++) {
+                // Q[b][nh][t][:] = inp[b][t][0][nh][:] (where : is the slice operator over HS)
+                const float* query_t = inp + b * T * C3 + t * C3 + nh * HS;
+                // out[b][nh][t][:]
+                float* out_bth = out + b * NH * T * T + nh * T * T + t * T;

                // pass 1: calculate query dot key and maxval
                for (int t2 = 0; t2 <= t; t2++) {
-                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
+                    // K[b][nh][t2][:] = inp[b][t2][1][nh][:]
+                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + nh * HS + C; // +C because it's key

-                    // (query_t) dot (key_t2)
+                    // Q[b][nh][t][:] dot K[b][nh][t2][:]
                    float val = 0.0f;
-                    for (int i = 0; i < hs; i++) {
+                    for (int i = 0; i < HS; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;

+                    // out[b][nh][t][t2] = val
                    out_bth[t2] = val;
                }
                for (int t2 = t + 1; t2 < T; ++t2) {
+                    // causal mask, using NAN to suppress warnings -> it could be -inf,
+                    // but it doesn't matter because validate_result ignores infinities/NANs
                    out_bth[t2] = NAN;
                }
            }
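A minimal sketch of how this CPU reference can be exercised; the sizes and buffer setup below are illustrative assumptions, not taken from this commit:

    int B = 2, T = 16, C = 64, NH = 4;                            // HS = C / NH = 16
    float* inp = (float*)malloc(B * T * 3 * C * sizeof(float));   // packed QKV, shape (B, T, 3, NH, HS)
    float* out = (float*)malloc(B * NH * T * T * sizeof(float));  // shape (B, NH, T, T)
    // ... fill inp with data ...
    trimul_cpu(out, inp, B, T, C, NH);
    // now out[((b * NH + nh) * T + t) * T + t2] holds scale * dot(Q[b,nh,t,:], K[b,nh,t2,:]) for t2 <= t,
    // and NAN above the causal diagonal.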
@@ -98,31 +106,31 @@ void trimul_cpu(float* out, const float* inp,
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
-                              int B, int N, int NH, int d) {
-    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)
-    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)
+                              int B, int T, int NH, int HS) {
+    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, T, HS)
+    // but instead, we have a single tensor QKV (inp) of shape (B, T, 3, NH, HS)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

-    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]
+    // Q[b][nh][t][hs] = inp[b][t][0][nh][hs]

-    if (idx < B * NH * N * d) {
-        int b = idx / (NH * N * d);
-        int rest = idx % (NH * N * d);
-        int nh_ = rest / (N * d);
-        rest = rest % (N * d);
-        int n = rest / d;
-        int d_ = rest % d;
+    if (idx < B * NH * T * HS) {
+        int b = idx / (NH * T * HS);
+        int rest = idx % (NH * T * HS);
+        int nh = rest / (T * HS);
+        rest = rest % (T * HS);
+        int t = rest / HS;
+        int hs = rest % HS;

        int inp_idx = \
-            (b * N * 3 * NH * d)
-            + (n * 3 * NH * d)
-            + (0 * NH * d)
-            + (nh_ * d)
-            + d_;
+            (b * T * 3 * NH * HS)
+            + (t * 3 * NH * HS)
+            + (0 * NH * HS)
+            + (nh * HS)
+            + hs;

        q[idx] = inp[inp_idx];
-        k[idx] = inp[inp_idx + NH * d];
-        v[idx] = inp[inp_idx + 2 * (NH * d)];
+        k[idx] = inp[inp_idx + NH * HS];
+        v[idx] = inp[inp_idx + 2 * (NH * HS)];
    }
}
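A sketch of how this kernel is typically launched for the cuBLAS path; the d_q/d_k/d_v carve-out of the d_qkvr scratch buffer, the d_inp name, and the block size of 256 are assumptions, not shown in this diff:

    // split the scratch buffer into three (B, NH, T, HS) tensors
    float* d_q = d_qkvr + 0 * B * T * C;
    float* d_k = d_qkvr + 1 * B * T * C;
    float* d_v = d_qkvr + 2 * B * T * C;
    int total_threads = B * NH * T * HS;  // one thread per element of Q/K/V
    int block_size = 256;
    int num_blocks = (total_threads + block_size - 1) / block_size;
    permute_kernel<<<num_blocks, block_size>>>(d_q, d_k, d_v, d_inp, B, T, NH, HS);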
@@ -145,6 +153,35 @@ void trimul_cublas(float* preatt,
    // batched matrix multiply with cuBLAS
    const float alpha = 1.0f / sqrtf(HS);
    const float beta = 0.0f;
+    // This schedules in parallel B*NH matmuls of shape q@k^T = (T, HS) @ (HS, T) = (T, T).
+    // IMPORTANT NOTE: cuBLAS uses a column-major representation (we use row-major in our codebase),
+    // so this call might look confusing if you compare it against the `cublasSgemmStridedBatched` signature.
+    //
+    // In order to avoid having to do an additional transpose operation after this call,
+    // we pass in K as the first argument and Q as the second, which might make you think we're computing K^T @ Q.
+    // Combine that with the shapes we got after the permute kernel - (B, NH, T, HS) (I'll omit B, NH for brevity going forward) -
+    // and you might think we end up with (HS, T) @ (T, HS) = (HS, HS).
+    // This is not the case. :)
+    //
+    // cuBLAS sees our row-major matrix (T, HS) as (HS, T), hence we set the leading dimensions to HS (see function signature).
+    // We transpose K and end up computing K^T @ Q = (T, HS) @ (HS, T) = (T, T).
+    // If you interpret the formula K^T @ Q literally, you might think we end up with:
+    // -----------------------------------
+    // k1.dot(q1) k1.dot(q2) ... k1.dot(qT)
+    // k2.dot(q1) k2.dot(q2) ... k2.dot(qT)
+    // ...
+    // kT.dot(q1) kT.dot(q2) ... kT.dot(qT)
+    // -----------------------------------
+    // But as mentioned, cuBLAS is column-major!
+    // So, given that the dot product is symmetric, we can write k1.dot(q1) as q1.dot(k1), and transposing the above
+    // representation we can see what we actually end up with in row-major format:
+    // -----------------------------------
+    // q1.dot(k1) q1.dot(k2) ... q1.dot(kT)
+    // q2.dot(k1) q2.dot(k2) ... q2.dot(kT)
+    // ...
+    // qT.dot(k1) qT.dot(k2) ... qT.dot(kT)
+    // -----------------------------------
+    // which is exactly what we wanted! :)
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
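The diff cuts the call off here. For reference, a sketch of how the remaining arguments are typically wired up for this layout; the strides and leading dimensions below are assumptions consistent with the comment above (K and Q are row-major (T, HS) slabs spaced T*HS floats apart per (batch, head), the output is a (T, T) slab per head), not lines from this commit:

    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          k, HS, T * HS,     // K: leading dim HS, stride T*HS per (b, nh)
                                          q, HS, T * HS,     // Q: same layout
                                          &beta,
                                          preatt, T, T * T,  // output: leading dim T, stride T*T per (b, nh)
                                          B * NH));          // batch count: one matmul per (batch, head)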
@@ -173,7 +210,7 @@ void trimul_cublas(float* preatt,
*/

// using creates an alias for a function pointer
-using matmul_fn_ptr = void (*)(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha);
+using matmul_fn_ptr = void (*)(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha);

template <matmul_fn_ptr matmul_tri>
__global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {
@@ -183,20 +220,21 @@ __global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float*
    // set up indices
    int C3 = C*3;
-    int hs = C / NH; // head size
-    float scale = 1.0 / sqrtf(hs);
+    int HS = C / NH; // head size
+    float scale = 1.0 / sqrtf(HS);

    // we put the "batch x head" dimension into the z block index.
-    int h = blockIdx.z % NH;
    int b = blockIdx.z / NH;
+    int nh = blockIdx.z % NH;

    // Get the base address for the current batch and head
-    const float* q = inp + b * T * C3 + h * hs;
-    const float* k = inp + b * T * C3 + h * hs + C;
-    float* r = out + (b*NH + h)*T*T;
+    // shapes -> inp (B, T, 3, NH, HS), Q (B, NH, T, HS), K (B, NH, T, HS)
+    const float* q = inp + b * T * C3 + nh * HS;      // Q[b][nh][:][:] = inp[b][:][0][nh][:]
+    const float* k = inp + b * T * C3 + nh * HS + C;  // K[b][nh][:][:] = inp[b][:][1][nh][:]
+    float* r = out + (b*NH + nh)*T*T;                 // out[b][nh][:][:]

    // start the multiplication
-    matmul_tri(r, T, q, C3, k, C3, T, hs, scale);
+    matmul_tri(r, T, k, C3, q, C3, T, HS, scale);
}
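The host-side launcher (trimul_launcher, referenced in a later hunk) is not visible here; the indexing above implies a launch configuration along these lines. This is a sketch under that assumption, not code from this commit:

    // 16x16 threads per block, each thread computes an 8x8 tile, so one block covers
    // a 128x128 tile of the (T, T) output; blockIdx.z enumerates the B*NH (batch, head) pairs.
    dim3 block(16, 16);
    dim3 grid((T + 127) / 128, (T + 127) / 128, B * NH);
    trimul_global<matmul_tri><<<grid, block>>>(out, inp, T, C, NH);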
template <matmul_fn_ptr matmul_tri>
@@ -239,12 +277,22 @@ void trimul_launcher(float* out, const float* inp, int B, int T, int C, int NH)
*/

// baseline implementation: 20 ms
-__device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
-    // get coordinates of our block
+__device__ void matmul_tri_naive(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
+    // coordinate system:
+    // | - - - - - > j
+    // |
+    // |
+    // v
+    // i
+    // get coordinates of our block - each thread is responsible for a single 8x8 block.
    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

-    // one more check to skip the upper diagonal in blocks that are on the diagonal.
+    // One more check to skip the upper diagonal in blocks that are on the diagonal.
+    // Note: we deliberately waste some compute on the jagged diagonal, i.e. on elements that belong
+    // to the upper triangle and should be masked out. Those values are ignored by `validate_result`,
+    // because the reference CPU implementation masks them out with NAN.
+    // Alternatively, the check could be done per element in the nested loop below (only write p[i * PS + j] when j <= i).
    if (j_base > i_base)
        return;
@@ -254,17 +302,17 @@ __device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const
        for (int jo = 0; jo < 8; ++jo) {
            int j = j_base + jo;
            float val = 0;
-            for (int s = 0; s < hs; ++s) {
-                val += k[i * ks + s] * q[j * qs + s];
+            for (int s = 0; s < HS; ++s) {
+                val += q[i * QS + s] * k[j * KS + s];
            }
-            p[i * ps + j] = val * alpha;
+            p[i * PS + j] = val * alpha;
        }
    }
}
/* ** Chapter IV - ... **
 *
- * Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send there runners of 64 times
+ * Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send their runners 64 times
 * to fetch the corresponding shapes. This is terribly inefficient; the runners need a minute or so for each trip,
 * but making a cookie can be done in just a second.
 *
@@ -292,25 +340,25 @@ __device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const
*/

// reorganize loops to enable data reuse: 3.5 ms
-__device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
+__device__ void matmul_tri_registers(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

    if (j_base > i_base)
        return;

    // shift our pointers to the sub-block this thread is responsible for
-    k += i_base * ks;
-    q += j_base * qs;
-    p += i_base * ps + j_base;
+    q += i_base * QS;
+    k += j_base * KS;
+    p += i_base * PS + j_base;

    float vals[8][8] = {};
-    for (int s = 0; s < hs; ++s) {
+    for (int hs = 0; hs < HS; ++hs) {
        float lhs[8];
        float rhs[8];
        for (int u = 0; u < 8; ++u) {
-            lhs[u] = k[u * ks + s];
-            rhs[u] = q[u * qs + s];
+            lhs[u] = q[u * QS + hs];
+            rhs[u] = k[u * KS + hs];
        }

        for (int i = 0; i < 8; ++i) {
@@ -322,7 +370,7 @@ __device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, c
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
-            p[i * ps + j] = vals[i][j] * alpha;
+            p[i * PS + j] = vals[i][j] * alpha;
        }
    }
}
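A quick back-of-the-envelope for why the register blocking above helps (my arithmetic, not stated in the commit):

    // Per iteration of the hs loop, each thread issues 16 global loads (8 floats of Q, 8 of K)
    // and performs 64 multiply-adds (the full 8x8 outer product accumulated in registers),
    // i.e. 4 multiply-adds per load. The naive kernel does 1 multiply-add per 2 loads,
    // so the arithmetic intensity per load improves by roughly 8x.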
@@ -334,7 +382,7 @@ __device__ void matmul_tri_registers(float* p, int ps, const float* k, int ks, c
 * "Of course", the runner answers, "but they've asked me for an elephant, a lion, a zebra, and a goldfish. These
 * are all over the place, I can't just pick them up at one spot (_strided access_).
 * "But the lion is right next to the palm tree. You could bring those two together?", you confirm.
- * "Yes", he says, "if the just asked for the different categories at the same time, that would make things
+ * "Yes", he says, "if they just asked for the different categories at the same time, that would make things
 * so much easier. See, I have this bucket, I could carry lots of things in one go if I could just scoop them up
 * from the same place (_coalesced access_).
 *
@@ -364,29 +412,30 @@ __device__ void st_vec(float* address, float4 val) {
364
412
}
365
413
366
414
// vector instructions for coalesced memory access: 1.7 ms
367
- __device__ void matmul_tri3 (float * p, int ps, const float * k, int ks, const float * q, int qs, int T, int hs, float alpha) {
415
+ __device__ void matmul_tri3 (float * p, int PS, const float * k, int KS, const float * q, int QS, int T, int HS, float alpha) {
416
+ // Same logic as previous kernel we just load in float4 to improve coalescing
368
417
int i_base = 128 * blockIdx .x + 8 * threadIdx .x ;
369
418
int j_base = 128 * blockIdx .y + 8 * threadIdx .y ;
370
419
371
420
if (j_base > i_base)
372
421
return ;
373
422
374
423
// shift our pointers to the sub-block this thread is responsible for
375
- k += i_base * ks ;
376
- q += j_base * qs ;
377
- p += i_base * ps + j_base;
424
+ q += i_base * QS ;
425
+ k += j_base * KS ;
426
+ p += i_base * PS + j_base;
378
427
379
428
float vals[8 ][8 ] = {};
380
- for (int s = 0 ; s < hs; s += 4 ) {
429
+ for (int hs = 0 ; hs < HS; hs += 4 ) {
381
430
// load in float4 to improve coalescing
382
431
float4 rhs[8 ];
383
432
for (int u = 0 ; u < 8 ; ++u) {
384
- rhs[u] = ld_vec (q + u * qs + s );
433
+ rhs[u] = ld_vec (k + u * KS + hs );
385
434
}
386
435
387
436
for (int i = 0 ; i < 8 ; ++i) {
388
437
// no need to keep lhs around for the i loop, its only reused in the j loop anyway.
389
- float4 lhs = ld_vec (k + i * ks + s );
438
+ float4 lhs = ld_vec (q + i * QS + hs );
390
439
for (int j = 0 ; j < 8 ; ++j) {
391
440
vals[i][j] += lhs.x * rhs[j].x ;
392
441
vals[i][j] += lhs.y * rhs[j].y ;
@@ -403,7 +452,7 @@ __device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const floa
            result.y = vals[i][j + 1] * alpha;
            result.z = vals[i][j + 2] * alpha;
            result.w = vals[i][j + 3] * alpha;
-            st_vec(p + i * ps + j, result);
+            st_vec(p + i * PS + j, result);
        }
    }
}
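ld_vec and st_vec are defined outside the hunks shown here; a typical implementation (an assumption, not copied from this file) simply reinterprets a 16-byte aligned float pointer as a float4:

    __device__ float4 ld_vec(const float* address) {
        // caller must guarantee 16-byte alignment of `address`
        return *reinterpret_cast<const float4*>(address);
    }

    __device__ void st_vec(float* address, float4 val) {
        *reinterpret_cast<float4*>(address) = val;
    }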
@@ -424,7 +473,7 @@ __device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const floa
 * details.]
 *
 */
-__device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha) {
+__device__ void matmul_tri4(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;
@@ -433,29 +482,38 @@ __device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const floa
    if (blockIdx.y > blockIdx.x)
        return;

-    k += 128 * blockIdx.x * ks;
-    q += 128 * blockIdx.y * qs;
+    q += 128 * blockIdx.x * QS;
+    k += 128 * blockIdx.y * KS;

    __shared__ float lhs_s[128][32];
    __shared__ float rhs_s[128][32];

    float vals[8][8] = {};
-    for (int so = 0; so < hs; so += 32) {
+    for (int so = 0; so < HS; so += 32) {
        // Read a large slice of the input, worked on together by all threads.
        // They are organized differently for this part. We want to ensure
        // fully coalesced loads, so we let a single warp handle consecutive
        // addresses, which means we need to combine two threadIdx.y values
        // in one read operation.
        // note: threads may read data here that they don't need themselves.
        // this really is a block-level operation.
+        // note2: the 16x16 threads of the block will, through this for loop, fetch 32 dims of 128 keys and 128 queries,
+        // i.e. from Q/K of shape (T, HS) they take q[:128, so:so+32] and k[:128, so:so+32].
        __syncthreads();
        for (int y = threadIdx.y / 2; y < 128; y += 8) {
            int xo = (threadIdx.y % 2) * 16;
-            lhs_s[y][threadIdx.x + xo] = k[y * ks + so + threadIdx.x + xo];
-            rhs_s[y][threadIdx.x + xo] = q[y * qs + so + threadIdx.x + xo];
+            lhs_s[y][threadIdx.x + xo] = q[y * QS + so + threadIdx.x + xo];
+            rhs_s[y][threadIdx.x + xo] = k[y * KS + so + threadIdx.x + xo];
        }
        __syncthreads();

+        // Now we compute a partial dot product (only 32 dims) for all combinations of keys and queries (128x128).
+        // Each thread does 8x8 of these partial dot products.
+        // E.g. thread (0,0) covers queries 0-7 and keys 0-7; more generally, the first row of threads
+        // (0,:) covers queries 0-7 with keys 0-127, and so on.
+        // In the next iterations of the outer (`so`) loop we keep accumulating into `vals` until we
+        // have the full dot products. Later we deposit them into the output matrix for all 8x8 blocks
+        // that are below the diagonal.
        for (int si = 0; si < 32; ++si) {
            float rhs[8];
            for (int u = 0; u < 8; ++u) {
@@ -484,7 +542,7 @@ __device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const floa
            result.y = vals[ii][ji + 1] * alpha;
            result.z = vals[ii][ji + 2] * alpha;
            result.w = vals[ii][ji + 3] * alpha;
-            st_vec(p + i * ps + j, result);
+            st_vec(p + i * PS + j, result);
        }
    }
}
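A closing note on the shared-memory budget of matmul_tri4 (my arithmetic, not stated in the commit):

    // lhs_s and rhs_s are each 128 * 32 floats = 16 KB, so one block uses 32 KB of shared memory.
    // __launch_bounds__(256, 2) asks for two resident blocks per SM, i.e. 64 KB of shared memory per SM,
    // which fits within the configurable shared-memory capacity of recent data-center GPUs
    // (on the order of 164 KB per SM on A100).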