
Commit 299d375

Merge pull request #540 from karpathy/move/fused_classifier
move fused classifier
2 parents dd4191f + 90a7bfe commit 299d375

File tree

llmc/fused_classifier.cuh
train_gpt2.cu

2 files changed: +153 −133 lines changed


llmc/fused_classifier.cuh

+143
@@ -0,0 +1,143 @@
+/*
+Fused Classifier:
+- Forwards the Cross Entropy Loss
+- Never materializes the full normalized logits, only at the target label
+- (fusion) Also kicks off the backward pass, because everything is already loaded
+*/
+// llmc internal imports
+#include "cuda_common.h"
+#include "cuda_utils.cuh"
+
+// ----------------------------------------------------------------------------
+// CUDA kernels
+
+struct SoftmaxParams {
+    float Scale;
+    float Offset;
+};
+
+__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {
+    // same but not float4
+    // one row of inp, i.e. inp[idx, :] of shape (V,)
+
+    const floatX* x = inp + idx * P;
+    float thread_maxval = -INFINITY;
+    float thread_sumval = 0.0f;
+    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;
+
+    // special-case loop to handle the unaligned elements at the end of the array
+    // this lets us skip the bounds check in the main loop below, which improves performance
+    while ((i+1)*x128::size > V) {
+        for(int k = 0; k < x128::size; ++k) {
+            if (i*x128::size+k >= V) {
+                break; // bounds checking against real V (rather than padded P)
+            }
+            float v = (float)x[i*x128::size+k];
+            float old_maxval = thread_maxval;
+            thread_maxval = fmaxf(thread_maxval, v);
+            thread_sumval *= expf((old_maxval - thread_maxval));
+            thread_sumval += expf(v - thread_maxval);
+        }
+        i -= blockDim.x;
+    }
+
+    // main loop for the bulk of the iterations (no bounds checking required!)
+    for (; i >= 0; i -= blockDim.x) {
+        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
+        for(int k = 0; k < x128::size; ++k) {
+            float v = (float)packed_x[k];
+            float old_maxval = thread_maxval;
+            thread_maxval = fmaxf(thread_maxval, v);
+            thread_sumval *= expf((old_maxval - thread_maxval));
+            thread_sumval += expf(v - thread_maxval);
+        }
+    }
+
+    // Block Max Reduction -> Maths -> Block Sum Reduction
+    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -INFINITY);
+    thread_sumval *= expf(thread_maxval - block_maxval);
+    float block_sumval = blockReduce<warpReduceSum>(thread_sumval);
+
+    // return the softmax parameters
+    return SoftmaxParams{1.f / block_sumval, block_maxval};
+}
+
+// will _update_ logits to logit gradients
+// uses template to decide whether to write logits and probs
+// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
+template <bool WriteLogits = true, bool WriteProbs = false>
+__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
+                fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
+                                         const float dloss, const int* targets,
+                                         int B, int T, int V, int P) {
+    // note: idx is small enough that it easily fits into 32 bit;
+    // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)
+    // are done in 64 bit
+    int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data
+    int ix = targets[idx];
+
+    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
+    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);
+
+    // calculate the probability needed for the loss and update (single-threaded)
+    if(threadIdx.x == 0) {
+        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
+        losses[idx] = (floatX)(-logf(prob));
+    }
+
+    // calculate the gradients directly, saves bandwidth from probs during training
+    // but also supports writing probs for inference-only and debugging
+    const floatX* logits_vec = logits + idx * P;
+    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
+        // this is the 2nd read of logits after the one in prepare_softmax2
+        // it will be overwritten by the logits gradients which is when we reduce cache persistence
+        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
+        x128 packed_probs;
+        for(int k = 0; k < x128::size; ++k) {
+            int element = i*x128::size + k;
+            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
+            packed_probs[k] = (floatX)prob;
+            float indicator = (element == ix) ? 1.0f : 0.0f;
+            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
+        }
+        if (WriteLogits){
+            // reduce cache persistence for the overwritten logits
+            // to maximise probability that logits remain in cache between prepare_softmax and here
+            store128cs(logits + idx * P + i * x128::size, packed_logits_vec);
+        }
+        if (WriteProbs) {
+            store128(probs + idx * P + i * x128::size, packed_probs);
+        }
+    }
+
+    // handle remaining elements after the last multiple of x128::size
+    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
+    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
+    for (int i = threadIdx.x + unaligned_start; i < V; i++) {
+        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
+        float indicator = (i == ix) ? 1.0f : 0.0f;
+        float dlogit = (prob - indicator) * dloss;
+        if (WriteLogits){
+            __stcs(logits + idx * P + i, (floatX)dlogit);
+        }
+        if (WriteProbs) {
+            probs[idx * P + i] = (floatX)prob;
+        }
+    }
+}
+
+// ----------------------------------------------------------------------------
+// kernel launchers
+
+// replaces logits with logit gradients
+template <typename Type>
+void fused_classifier(Type* logits, Type* losses,
+                      const float dloss, const int* targets,
+                      int B, int T, int V, int P) {
+    NVTX_RANGE_FN();
+    const int block_size = 1024;
+    const int N = B * T;
+    const int grid_size = N;
+    fused_classifier_kernel5<<<grid_size, block_size>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
+    cudaCheck(cudaGetLastError());
+}
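
For reference, here is the per-row math that the new kernel fuses into a single pass, written as a host-side sketch in plain C. This is illustrative only (it is not part of the commit, and the function name reference_classifier_row is hypothetical); it mirrors the online-softmax update in prepare_softmax_blockwide3 and the (prob - indicator) * dloss gradient in fused_classifier_kernel5.

// reference sketch: per-row cross-entropy loss and in-place logit gradient,
// i.e. what fused_classifier_kernel5 computes for one (b, t) position
#include <math.h>

void reference_classifier_row(float* logits, float* loss, int target, float dloss, int V) {
    // online softmax: track a running max and rescale the running sum
    // whenever a new max is found (same recurrence as the kernel's thread loop)
    float maxval = -INFINITY, sumval = 0.0f;
    for (int i = 0; i < V; i++) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, logits[i]);
        sumval = sumval * expf(old_maxval - maxval) + expf(logits[i] - maxval);
    }
    float scale = 1.0f / sumval;   // plays the role of SoftmaxParams.Scale
    float offset = maxval;         // plays the role of SoftmaxParams.Offset
    // forward: loss = -log(softmax[target]); only the target's probability is needed
    *loss = -logf(expf(logits[target] - offset) * scale);
    // fused backward: dlogit_i = (softmax_i - 1{i == target}) * dloss, written in place
    for (int i = 0; i < V; i++) {
        float prob = expf(logits[i] - offset) * scale;
        float indicator = (i == target) ? 1.0f : 0.0f;
        logits[i] = (prob - indicator) * dloss;
    }
}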

train_gpt2.cu

+10 −133
@@ -10,11 +10,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include <string_view>
 #include <sys/stat.h>
 #include <sys/types.h>
-#ifdef MULTI_GPU
-#include <mpi.h>
-#include <nccl.h>
-#endif
-// our own utilities
+// ----------- CPU utilities -----------
 // defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
 // defines: create_dir_if_not_exists, find_max_step
 #include "llmc/utils.h"
@@ -31,6 +27,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/logger.h"
 // defines: get_flops_promised
 #include "llmc/mfu.h"
+// ----------- GPU utilities -----------
 // defines:
 // WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE
 // NVTX_RANGE_FN
@@ -42,6 +39,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
 // defines: cublas_compute, cublaslt_handle, cublas_handle
 #include "llmc/cublas_common.h"
+// ----------- Layer implementations in CUDA -----------
 // defines: encoder_forward, encoder_backward
 #include "llmc/encoder.cuh"
 // defines: gelu_forward, gelu_backward_inplace
@@ -53,6 +51,13 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 // defines: attention_forward, attention_backward
 #include "llmc/attention.cuh"
 #endif
+// defines: fused_classifier
+#include "llmc/fused_classifier.cuh"
+// ----------- Multi-GPU support -----------
+#ifdef MULTI_GPU
+#include <mpi.h>
+#include <nccl.h>
+#endif
 
 // ----------------------------------------------------------------------------
 // global var containing information about the GPU this process is running on
@@ -730,121 +735,6 @@ __global__ void global_norm_squared_kernel(float* out, const T* data, size_t cou
 }
 }
 
-struct SoftmaxParams {
-    float Scale;
-    float Offset;
-};
-
-__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {
-    // same but not float4
-    // one row of inp, i.e. inp[idx, :] of shape (V,)
-
-    const floatX* x = inp + idx * P;
-    float thread_maxval = -INFINITY;
-    float thread_sumval = 0.0f;
-    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;
-
-    // special-case loop to handle the unaligned elements at the end of the array
-    // this lets us skip the bounds check in the main loop below, which improves performance
-    while ((i+1)*x128::size > V) {
-        for(int k = 0; k < x128::size; ++k) {
-            if (i*x128::size+k >= V) {
-                break; // bounds checking against real V (rather than padded P)
-            }
-            float v = (float)x[i*x128::size+k];
-            float old_maxval = thread_maxval;
-            thread_maxval = fmaxf(thread_maxval, v);
-            thread_sumval *= expf((old_maxval - thread_maxval));
-            thread_sumval += expf(v - thread_maxval);
-        }
-        i -= blockDim.x;
-    }
-
-    // main loop for the bulk of the iterations (no bounds checking required!)
-    for (; i >= 0; i -= blockDim.x) {
-        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
-        for(int k = 0; k < x128::size; ++k) {
-            float v = (float)packed_x[k];
-            float old_maxval = thread_maxval;
-            thread_maxval = fmaxf(thread_maxval, v);
-            thread_sumval *= expf((old_maxval - thread_maxval));
-            thread_sumval += expf(v - thread_maxval);
-        }
-    }
-
-    // Block Max Reduction -> Maths -> Block Sum Reduction
-    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -INFINITY);
-    thread_sumval *= expf(thread_maxval - block_maxval);
-    float block_sumval = blockReduce<warpReduceSum>(thread_sumval);
-
-    // return the softmax parameters
-    return SoftmaxParams{1.f / block_sumval, block_maxval};
-}
-
-// will _update_ logits to logit gradients
-// uses template to decide whether to write logits and probs
-// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
-template <bool WriteLogits = true, bool WriteProbs = false>
-__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
-                fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
-                                         const float dloss, const int* targets,
-                                         int B, int T, int V, int P) {
-    // note: idx is small enough that it easily fits into 32 bit;
-    // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)
-    // are done is 64 bit
-    int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data
-    int ix = targets[idx];
-
-    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
-    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);
-
-    // calculate the probability needed for the loss and update (single-threaded)
-    if(threadIdx.x == 0) {
-        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
-        losses[idx] = (floatX)(-logf(prob));
-    }
-
-    // calculate the gradients directly, saves bandwidth from probs during training
-    // but also supports writing probs for inference-only and debugging
-    const floatX* logits_vec = logits + idx * P;
-    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
-        // this is the 2nd read of logits after the one in prepare_softmax2
-        // it will be overwritten by the logits gradients which is when we reduce cache persistence
-        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
-        x128 packed_probs;
-        for(int k = 0; k < x128::size; ++k) {
-            int element = i*x128::size + k;
-            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
-            packed_probs[k] = (floatX)prob;
-            float indicator = (element == ix) ? 1.0f : 0.0f;
-            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
-        }
-        if (WriteLogits){
-            // reduce cache persistence for the overwritten logits
-            // to maximise probability that logits remain in cache between prepare_softmax and here
-            store128cs(logits + idx * P + i * x128::size, packed_logits_vec);
-        }
-        if (WriteProbs) {
-            store128(probs + idx * P + i * x128::size, packed_probs);
-        }
-    }
-
-    // handle remaining elements after the last multiple of x128::size
-    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
-    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
-    for (int i = threadIdx.x + unaligned_start; i < V; i++) {
-        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
-        float indicator = (i == ix) ? 1.0f : 0.0f;
-        float dlogit = (prob - indicator) * dloss;
-        if (WriteLogits){
-            __stcs(logits + idx * P + i, (floatX)dlogit);
-        }
-        if (WriteProbs) {
-            probs[idx * P + i] = (floatX)prob;
-        }
-    }
-}
-
 // device functions and the kernel to cast data between types
 template<typename Td, typename Ts>
 __device__ Td cast_value(Ts val);
@@ -1050,19 +940,6 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr
 cudaCheck(cudaGetLastError());
 }
 
-// replaces logits with logit gradients
-template <typename Type>
-void fused_classifier(Type* logits, Type* losses,
-                      const float dloss, const int* targets,
-                      int B, int T, int V, int P) {
-    NVTX_RANGE_FN();
-    const int block_size = 1024;
-    const int N = B * T;
-    const int grid_size = N;
-    fused_classifier_kernel5<<<grid_size, block_size>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
-    cudaCheck(cudaGetLastError());
-}
-
 template<typename T>
 void global_norm_squared(float* out, const T* values, size_t count) {
     const int block_size = 512;
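
For orientation, a hypothetical call site for the moved launcher is sketched below; the variable names (logits, losses, targets, Vp) are illustrative and not taken verbatim from train_gpt2.cu. The launcher deduces Type from the pointers, launches one 1024-thread block per (b, t) token, and overwrites the logits with their gradients in place, so the classifier needs no separate backward kernel.

// hypothetical call site (names illustrative):
//   logits  : device floatX, shape (B, T, Vp), output of the final matmul
//   losses  : device floatX, shape (B, T)
//   targets : device int, shape (B, T), token ids
//   V is the real vocab size, Vp the padded vocab size used by the matmuls
float dloss = 1.0f / (B * T); // gradient of the mean loss w.r.t. each per-token loss
fused_classifier(logits, losses, dloss, targets, B, T, V, Vp);
// logits now holds d(loss)/d(logits) and can feed the backward matmul directly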
