@@ -1723,7 +1723,7 @@ typedef struct {
floatX* fch_gelu; // (L, B, T, 4*C)
floatX* fcproj; // (L, B, T, C)
floatX* residual3; // (L, B, T, C)
- floatX* lnf; // (B, T, C)
+ floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms
floatX* lnf_mean; // (B, T)
floatX* lnf_rstd; // (B, T)
floatX* losses; // (B, T)
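For context, the -r / recompute setting referenced in the new comment trades memory for extra compute in the backward pass. The summary below reflects how the flag is used elsewhere in train_gpt2.cu; the exact level semantics are not part of this diff, so treat it as an assumption:

    // Recompute levels (assumed summary, not verbatim from the source):
    //   recompute == 0 : keep every forward activation (most memory, no recomputation)
    //   recompute == 1 : recompute GeLU during backward (fch_gelu shrinks to a shared scratch)
    //   recompute >= 2 : additionally recompute LayerNorm during backward; ln1/ln2 are never stored
    //                    per layer, and this single (B, T, C) lnf buffer is reused for all of them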
@@ -1744,7 +1744,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
size_t C = config.channels;
act_sizes[0] = B * T * C; // encoded
// if recompute >= 2 then we will recompute the layernorm forward activation during backward pass
- act_sizes[1] = (recompute < 2) ? L * B * T * C : B * T * C; // ln1
+ act_sizes[1] = (recompute < 2) ? L * B * T * C : 0; // ln1
act_sizes[2] = L * B * T; // ln1_mean
act_sizes[3] = L * B * T; // ln1_rstd
act_sizes[4] = L * B * T * C; // atty
@@ -1757,7 +1757,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
act_sizes[6] = L * B * T * C; // attproj
act_sizes[7] = L * B * T * C; // residual2
// if recompute >= 2 then we will recompute the layernorm forward activation during backward pass
- act_sizes[8] = (recompute < 2) ? L * B * T * C : B * T * C; // ln2
+ act_sizes[8] = (recompute < 2) ? L * B * T * C : 0; // ln2
act_sizes[9] = L * B * T; // ln2_mean
act_sizes[10] = L * B * T; // ln2_rstd
act_sizes[11] = L * B * T * 4*C; // fch
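To make the effect of the two zeroed entries concrete, here is a rough back-of-the-envelope estimate. The model and batch shape (GPT-2 124M with L = 12, C = 768, plus B = 4, T = 1024 and bf16 activations) are illustrative assumptions, not values from this commit:

    // ln1 and ln2 each used to keep one (B, T, C) slice per layer when recompute < 2:
    size_t bytes_kept = 2 * 12 * 4 * 1024 * 768 * 2;   // = 150,994,944 bytes ~= 144 MiB
    // With -r 2 both entries are now 0. Compared to the previous -r 2 behavior the delta is smaller
    // but still real: the two shared (B, T, C) scratch slices (~12 MiB for this config) are no longer
    // allocated at all, since every layernorm now reuses the existing lnf buffer.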
@@ -1810,8 +1810,13 @@ void* malloc_and_point(floatX** targets[], const size_t* act_sizes, size_t n) {
cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(floatX)));
char* acts_memory_iterator = (char*)acts_memory;
for (size_t i = 0; i < n; i++) {
- *(targets[i]) = (floatX*)acts_memory_iterator;
- acts_memory_iterator += act_sizes[i] * sizeof(floatX);
+ // extra protection so we don't accidentally use an empty buffer
+ if (act_sizes[i] == 0) {
+     *(targets[i]) = NULL;
+ } else {
+     *(targets[i]) = (floatX*)acts_memory_iterator;
+     acts_memory_iterator += act_sizes[i] * sizeof(floatX);
+ }
}
return acts_memory;
}
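One way to read the new branch: a zero-sized entry contributes nothing to the allocation in the first place, because the total handed to cudaMalloc is just the sum of act_sizes. The summation below is a minimal sketch of that caller-side bookkeeping, not a verbatim quote of it. Handing back NULL instead of the current iterator means any code path that still dereferences acts.ln1 or acts.ln2 under -r 2 fails immediately rather than silently aliasing whatever tensor happens to come next in the arena.

    // minimal sketch of how the total is formed before malloc_and_point is called
    size_t num_activations = 0;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        num_activations += act_sizes[i];
    }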
@@ -2177,12 +2182,12 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, in
floatX* l_fcprojb = params.fcprojb + l * C;

// get the pointers of the activations for this layer
- floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.ln1;
+ floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_attproj = acts.attproj + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
- floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.ln2;
+ floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
floatX* l_ln2_mean = acts.ln2_mean + l * B * T;
floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T;
floatX* l_fch = acts.fch + l * B * T * 4*C;
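Pointing both per-layer layernorm outputs at the single acts.lnf buffer is safe in the forward pass because each layernorm result is consumed by the very next matmul before the buffer is overwritten. The ordering sketch below uses the layernorm_forward kernel found elsewhere in train_gpt2.cu; its exact signature at this location is an assumption, since those call sites are outside the hunks shown:

    layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); // writes into acts.lnf
    // the QKV matmul reads l_ln1 (i.e. lnf) immediately, before lnf is touched again
    // ... attention, attproj, residual add ...
    layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); // lnf is now overwritten
    // the MLP up-projection consumes l_ln2 right away as well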
@@ -2214,7 +2219,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, in

// OK, fusion across blocks.
if (l+1 != L) {
- floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.ln1;
+ floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;
floatX* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T;
floatX* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T;
const floatX* l_ln1w = params.ln1w + (l + 1) * C;
@@ -2324,6 +2329,10 @@ void gpt2_backward(GPT2 *model, int* inputs) {
floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);

+ // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
+ // scratch for backward computations
+ floatX* dl_btc = residual;
+
// now backward all the layers
for (int l = L-1; l >= 0; l--) {
NvtxRange layer_range("Layer", l);
@@ -2353,13 +2362,13 @@ void gpt2_backward(GPT2 *model, int* inputs) {
floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
floatX* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
- floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.ln1;
+ floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
floatX* l_ln1_mean = acts.ln1_mean + l * B * T;
floatX* l_ln1_rstd = acts.ln1_rstd + l * B * T;
floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
- floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.ln2;
+ floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
floatX* l_ln2_mean = acts.ln2_mean + l * B * T;
floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T;
floatX* l_fch = acts.fch + l * B * T * 4*C;
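The backward pass is where the shared buffer has to be refilled: acts.lnf only ever holds the most recently computed layernorm, so before the gradient of a matmul that consumed layer l's ln1 or ln2 can be formed, that activation is recomputed from the saved residual, exactly as the comments in fill_in_activation_sizes promise. The recomputation itself lives further down in gpt2_backward and is not part of the hunks shown, so the sketch below is an assumption about its shape rather than a quote; l_ln2w/l_ln2b are assumed to be the usual per-layer parameter pointers. This is also why the next hunk stops co-opting acts.lnf as the dl_btc scratch: under -r 2 that buffer is live again during backward, and the just-freed last-layer residual is used instead.

    if (model->recompute >= 2) {
        // refill acts.lnf (aliased by l_ln2) from the stored residual before its gradient consumers run
        layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
    }
    // ... the MLP matmul backward uses l_ln2 here; the same refill happens for l_ln1 before its consumer ...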
@@ -2368,9 +2377,6 @@ void gpt2_backward(GPT2 *model, int* inputs) {
// notice that there is no l *, because we just have a single copy, and keep
// re-using this memory in every Transformer block as we calculate backward pass

- // we need a B x T x C buffer; thankfully, the forward activation for lnf isn't needed anymore,
- // so we can co-opt it here.
- floatX* dl_btc = (floatX*)acts.lnf;
floatX* dl_bt4c = (floatX*)grads_acts.bt4c;

// start the backward pass for this layer