Commit d17893b

Merge pull request #532 from ngc92/ln-buffers
more efficient use of memory buffers for LN recomputation
2 parents ac93145 + 0d16b51 commit d17893b
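
In short: with LN recomputation enabled (-r 2 and above), the ln1 and ln2 activation entries are sized to zero and the single (B, T, C) lnf buffer is reused as the layernorm output of every layer, while the backward pass reuses the last residual buffer as its dl_btc scratch instead of acts.lnf, since lnf must keep holding the recomputed layernorm values. Below is a minimal standalone sketch of the sizing logic only; the shapes (L=12, B=4, T=1024, C=768) and the 2-byte element size are illustrative assumptions, not values fixed by this commit.

// sketch: how the ln1/ln2/lnf activation sizes behave per recompute level
// (shapes and the 2-byte element size are example values, not from the PR)
#include <stdio.h>

int main(void) {
    size_t L = 12, B = 4, T = 1024, C = 768;   // illustrative GPT-2-small-like shapes
    size_t elem = 2;                           // e.g. a 2-byte floatX such as BF16
    for (int recompute = 0; recompute <= 2; recompute++) {
        size_t ln1 = (recompute < 2) ? L * B * T * C : 0;  // no per-layer buffer at all for -r 2
        size_t ln2 = (recompute < 2) ? L * B * T * C : 0;
        size_t lnf = B * T * C;                            // single buffer, reused for all layernorms
        printf("recompute=%d: ln1+ln2+lnf = %.1f MiB\n",
               recompute, (ln1 + ln2 + lnf) * elem / (1024.0 * 1024.0));
    }
    return 0;
}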

File tree: 1 file changed (+19 -13 lines)
train_gpt2.cu (+19 -13)

@@ -1723,7 +1723,7 @@ typedef struct {
     floatX* fch_gelu; // (L, B, T, 4*C)
     floatX* fcproj; // (L, B, T, C)
     floatX* residual3; // (L, B, T, C)
-    floatX* lnf; // (B, T, C)
+    floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms
     floatX* lnf_mean; // (B, T)
     floatX* lnf_rstd; // (B, T)
     floatX* losses; // (B, T)
@@ -1744,7 +1744,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
     size_t C = config.channels;
     act_sizes[0] = B * T * C; // encoded
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
-    act_sizes[1] = (recompute < 2) ? L * B * T * C : B * T * C; // ln1
+    act_sizes[1] = (recompute < 2) ? L * B * T * C : 0; // ln1
     act_sizes[2] = L * B * T; // ln1_mean
     act_sizes[3] = L * B * T; // ln1_rstd
     act_sizes[4] = L * B * T * C; // atty
@@ -1757,7 +1757,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
     act_sizes[6] = L * B * T * C; // attproj
     act_sizes[7] = L * B * T * C; // residual2
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
-    act_sizes[8] = (recompute < 2) ? L * B * T * C : B * T * C; // ln2
+    act_sizes[8] = (recompute < 2) ? L * B * T * C : 0; // ln2
     act_sizes[9] = L * B * T; // ln2_mean
     act_sizes[10] = L * B * T; // ln2_rstd
     act_sizes[11] = L * B * T * 4*C; // fch
@@ -1810,8 +1810,13 @@ void* malloc_and_point(floatX** targets[], const size_t* act_sizes, size_t n) {
     cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(floatX)));
     char* acts_memory_iterator = (char*)acts_memory;
     for (size_t i = 0; i < n; i++) {
-        *(targets[i]) = (floatX*)acts_memory_iterator;
-        acts_memory_iterator += act_sizes[i] * sizeof(floatX);
+        // extra protection so we don't accidentally use an empty buffer
+        if(act_sizes[i] == 0) {
+            *(targets[i]) = NULL;
+        }else {
+            *(targets[i]) = (floatX*) acts_memory_iterator;
+            acts_memory_iterator += act_sizes[i] * sizeof(floatX);
+        }
     }
     return acts_memory;
 }
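
For reference, a self-contained host-side sketch of the same carving-with-guard pattern (plain malloc and float stand in for cudaMalloc and floatX; the three-buffer setup and sizes are made up for the example): zero-sized entries map to NULL so that any accidental use fails loudly instead of silently aliasing the next buffer.

// standalone sketch of carving one allocation into sub-buffers with a zero-size guard
#include <stdlib.h>
#include <stdio.h>

int main(void) {
    float *ln1, *ln2, *lnf;
    float** targets[] = { &ln1, &ln2, &lnf };
    size_t sizes[] = { 0, 0, 8 };              // ln1/ln2 disabled, lnf gets 8 floats
    size_t total = sizes[0] + sizes[1] + sizes[2];

    char* memory = malloc(total * sizeof(float));
    char* it = memory;
    for (size_t i = 0; i < 3; i++) {
        if (sizes[i] == 0) {
            *(targets[i]) = NULL;              // guard: never hand out a zero-length slice
        } else {
            *(targets[i]) = (float*)it;
            it += sizes[i] * sizeof(float);
        }
    }
    printf("ln1=%p ln2=%p lnf=%p\n", (void*)ln1, (void*)ln2, (void*)lnf);
    free(memory);
    return 0;
}
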
@@ -2177,12 +2182,12 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, in
         floatX* l_fcprojb = params.fcprojb + l * C;

         // get the pointers of the activations for this layer
-        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.ln1;
+        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
         floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
         floatX* l_atty = acts.atty + l * B * T * C;
         floatX* l_attproj = acts.attproj + l * B * T * C;
         floatX* l_residual2 = acts.residual2 + l * B * T * C;
-        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.ln2;
+        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
         floatX* l_ln2_mean = acts.ln2_mean + l * B * T;
         floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T;
         floatX* l_fch = acts.fch + l * B * T * 4*C;
@@ -2214,7 +2219,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, in

         // OK, fusion across blocks.
         if(l+1 != L) {
-            floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.ln1;
+            floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;
             floatX* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T;
             floatX* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T;
             const floatX* l_ln1w = params.ln1w + (l + 1) * C;
@@ -2324,6 +2329,10 @@ void gpt2_backward(GPT2 *model, int* inputs) {
     floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass
     layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);

+    // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
+    // scratch for backward computations
+    floatX* dl_btc = residual;
+
     // now backward all the layers
     for (int l = L-1; l >= 0; l--) {
         NvtxRange layer_range("Layer", l);
@@ -2353,13 +2362,13 @@ void gpt2_backward(GPT2 *model, int* inputs) {
         floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
         floatX* dl_fcprojb = grads.fcprojb + l * C;
         // get the pointers of the activations for this layer
-        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.ln1;
+        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
         floatX* l_ln1_mean = acts.ln1_mean + l * B * T;
         floatX* l_ln1_rstd = acts.ln1_rstd + l * B * T;
         floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
         floatX* l_atty = acts.atty + l * B * T * C;
         floatX* l_residual2 = acts.residual2 + l * B * T * C;
-        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.ln2;
+        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
         floatX* l_ln2_mean = acts.ln2_mean + l * B * T;
         floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T;
         floatX* l_fch = acts.fch + l * B * T * 4*C;
@@ -2368,9 +2377,6 @@ void gpt2_backward(GPT2 *model, int* inputs) {
         // notice that there is no l *, because we just have a single copy, and keep
         // re-using this memory in every Transformer block as we calculate backward pass

-        // we need a B x T x C buffer; thankfully, the forward activation for lnf isn't needed anymore,
-        // so we can co-opt it here.
-        floatX* dl_btc = (floatX*)acts.lnf;
         floatX* dl_bt4c = (floatX*)grads_acts.bt4c;

         // start the backward pass for this layer
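
The two hunks above belong together: dl_btc used to co-opt acts.lnf as its (B, T, C) scratch, but with LN recomputation the lnf buffer now has to keep holding the recomputed layernorm outputs, so the scratch moves to the last residual instead, whose values are no longer needed once the final layernorm_backward has consumed them. A toy host-side sketch of that kind of buffer aliasing (plain floats and a made-up size, not the actual training code):

// toy sketch: reusing a buffer whose values are no longer needed as generic scratch
#include <stdio.h>

int main(void) {
    float residual[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // pretend forward output, already consumed
    float* dl_btc = residual;                      // alias the same storage as scratch
    for (int i = 0; i < 8; i++) {
        dl_btc[i] = 0.0f;                          // free to overwrite it now
    }
    printf("dl_btc aliases residual at %p, first value now %.1f\n",
           (void*)dl_btc, dl_btc[0]);
    return 0;
}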
