@@ -10,11 +10,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
10
10
#include < string_view>
11
11
#include < sys/stat.h>
12
12
#include < sys/types.h>
13
- #ifdef MULTI_GPU
14
- #include < mpi.h>
15
- #include < nccl.h>
16
- #endif
17
- // our own utilities
13
+ // ----------- CPU utilities -----------
18
14
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
19
15
// defines: create_dir_if_not_exists, find_max_step
20
16
#include " llmc/utils.h"
@@ -31,6 +27,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
31
27
#include " llmc/logger.h"
32
28
// defines: get_flops_promised
33
29
#include " llmc/mfu.h"
30
+ // ----------- GPU utilities -----------
34
31
// defines:
35
32
// WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE
36
33
// NVTX_RANGE_FN
@@ -42,6 +39,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
42
39
// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
43
40
// defines: cublas_compute, cublaslt_handle, cublas_handle
44
41
#include " llmc/cublas_common.h"
42
+ // ----------- Layer implementations in CUDA -----------
45
43
// defines: encoder_forward, encoder_backward
46
44
#include " llmc/encoder.cuh"
47
45
// defines: gelu_forward, gelu_backward_inplace
@@ -53,6 +51,13 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
53
51
// defines: attention_forward, attention_backward
54
52
#include " llmc/attention.cuh"
55
53
#endif
54
+ // defines: fused_classifier
55
+ #include " llmc/fused_classifier.cuh"
56
+ // ----------- Multi-GPU support -----------
57
+ #ifdef MULTI_GPU
58
+ #include < mpi.h>
59
+ #include < nccl.h>
60
+ #endif
56
61
57
62
// ----------------------------------------------------------------------------
58
63
// global var containing information about the GPU this process is running on
@@ -730,121 +735,6 @@ __global__ void global_norm_squared_kernel(float* out, const T* data, size_t cou
730
735
}
731
736
}
732
737
733
- struct SoftmaxParams {
734
- float Scale;
735
- float Offset;
736
- };
737
-
738
- __device__ SoftmaxParams prepare_softmax_blockwide3 (int64_t idx, const floatX* inp, int V, int P) {
739
- // same but not float4
740
- // one row of inp, i.e. inp[idx, :] of shape (V,)
741
-
742
- const floatX* x = inp + idx * P;
743
- float thread_maxval = -INFINITY;
744
- float thread_sumval = 0 .0f ;
745
- int i = (V+x128::size-1 )/x128::size + threadIdx .x - blockDim .x ;
746
-
747
- // special-case loop to handle the unaligned elements at the end of the array
748
- // this lets us skip the bounds check in the main loop below, which improves performance
749
- while ((i+1 )*x128::size > V) {
750
- for (int k = 0 ; k < x128::size; ++k) {
751
- if (i*x128::size+k >= V) {
752
- break ; // bounds checking against real V (rather than padded P)
753
- }
754
- float v = (float )x[i*x128::size+k];
755
- float old_maxval = thread_maxval;
756
- thread_maxval = fmaxf (thread_maxval, v);
757
- thread_sumval *= expf ((old_maxval - thread_maxval));
758
- thread_sumval += expf (v - thread_maxval);
759
- }
760
- i -= blockDim .x ;
761
- }
762
-
763
- // main loop for the bulk of the iterations (no bounds checking required!)
764
- for (; i >= 0 ; i -= blockDim .x ) {
765
- x128 packed_x = load128 (x + i * x128::size); // load and keep in cache until fused_classifier loop
766
- for (int k = 0 ; k < x128::size; ++k) {
767
- float v = (float )packed_x[k];
768
- float old_maxval = thread_maxval;
769
- thread_maxval = fmaxf (thread_maxval, v);
770
- thread_sumval *= expf ((old_maxval - thread_maxval));
771
- thread_sumval += expf (v - thread_maxval);
772
- }
773
- }
774
-
775
- // Block Max Reduction -> Maths -> Block Sum Reduction
776
- float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false , -INFINITY);
777
- thread_sumval *= expf (thread_maxval - block_maxval);
778
- float block_sumval = blockReduce<warpReduceSum>(thread_sumval);
779
-
780
- // return the softmax parameters
781
- return SoftmaxParams{1 .f / block_sumval, block_maxval};
782
- }
783
-
784
- // will _update_ logits to logit gradients
785
- // uses template to decide whether to write logits and probs
786
- // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
787
- template <bool WriteLogits = true , bool WriteProbs = false >
788
- __global__ void __launch_bounds__ (1024 , MAX_1024_THREADS_BLOCKS)
789
- fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
790
- const float dloss, const int * targets,
791
- int B, int T, int V, int P) {
792
- // note: idx is small enough that it easily fits into 32 bit;
793
- // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)
794
- // are done is 64 bit
795
- int64_t idx = gridDim .x - (blockIdx .x +1 ); // reverse order for cache hits on matmul data
796
- int ix = targets[idx];
797
-
798
- // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
799
- SoftmaxParams sp = prepare_softmax_blockwide3 (idx, logits, V, P);
800
-
801
- // calculate the probability needed for the loss and update (single-threaded)
802
- if (threadIdx .x == 0 ) {
803
- float prob = expf ((float )logits[idx * P + ix] - sp.Offset ) * sp.Scale ;
804
- losses[idx] = (floatX)(-logf (prob));
805
- }
806
-
807
- // calculate the gradients directly, saves bandwidth from probs during training
808
- // but also supports writing probs for inference-only and debugging
809
- const floatX* logits_vec = logits + idx * P;
810
- for (int i = threadIdx .x ; i < V/x128::size; i += blockDim .x ) {
811
- // this is the 2nd read of logits after the one in prepare_softmax2
812
- // it will be overwritten by the logits gradients which is when we reduce cache persistence
813
- x128 packed_logits_vec = load128 (logits_vec + i * x128::size); // rely on cs of store128cs
814
- x128 packed_probs;
815
- for (int k = 0 ; k < x128::size; ++k) {
816
- int element = i*x128::size + k;
817
- float prob = expf ((float )packed_logits_vec[k] - sp.Offset ) * sp.Scale ;
818
- packed_probs[k] = (floatX)prob;
819
- float indicator = (element == ix) ? 1 .0f : 0 .0f ;
820
- packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
821
- }
822
- if (WriteLogits){
823
- // reduce cache persistence for the overwritten logits
824
- // to maximise probability that logits remain in cache between prepare_softmax and here
825
- store128cs (logits + idx * P + i * x128::size, packed_logits_vec);
826
- }
827
- if (WriteProbs) {
828
- store128 (probs + idx * P + i * x128::size, packed_probs);
829
- }
830
- }
831
-
832
- // handle remaining elements after the last multiple of x128::size
833
- // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
834
- int unaligned_start = V & ~(x128::size - 1 ); // round down to multiple of x128::size
835
- for (int i = threadIdx .x + unaligned_start; i < V; i++) {
836
- float prob = expf ((float )logits_vec[i] - sp.Offset ) * sp.Scale ;
837
- float indicator = (i == ix) ? 1 .0f : 0 .0f ;
838
- float dlogit = (prob - indicator) * dloss;
839
- if (WriteLogits){
840
- __stcs (logits + idx * P + i, (floatX)dlogit);
841
- }
842
- if (WriteProbs) {
843
- probs[idx * P + i] = (floatX)prob;
844
- }
845
- }
846
- }
847
-
848
738
// device functions and the kernel to cast data between types
849
739
template <typename Td, typename Ts>
850
740
__device__ Td cast_value (Ts val);
@@ -1050,19 +940,6 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr
1050
940
cudaCheck (cudaGetLastError ());
1051
941
}
1052
942
1053
- // replaces logits with logit gradients
1054
- template <typename Type>
1055
- void fused_classifier (Type* logits, Type* losses,
1056
- const float dloss, const int * targets,
1057
- int B, int T, int V, int P) {
1058
- NVTX_RANGE_FN ();
1059
- const int block_size = 1024 ;
1060
- const int N = B * T;
1061
- const int grid_size = N;
1062
- fused_classifier_kernel5<<<grid_size, block_size>>> (logits, losses, (floatX*)NULL , dloss, targets, B, T, V, P);
1063
- cudaCheck (cudaGetLastError ());
1064
- }
1065
-
1066
943
template <typename T>
1067
944
void global_norm_squared (float * out, const T* values, size_t count) {
1068
945
const int block_size = 512 ;
0 commit comments