I can backward through the MLP block. Attention block is next.
karpathy committed Sep 27, 2024
1 parent 2c4b3cc commit 1b54612
Showing 3 changed files with 106 additions and 45 deletions.
46 changes: 46 additions & 0 deletions llmc/swiglu.cuh
@@ -63,6 +63,41 @@ __global__ void swiglu_forward_kernel2(floatX* out, const floatX* inp, int B, in
out[idx] = (floatX)((x1 * x2) / (1.0f + expf(-x2)));
}

__global__ void swiglu_backward_kernel1(floatX* dinp, const floatX* dout, const floatX* inp, int B, int T, int C) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    const floatX* dout_ptr = dout + idx;
    // b,t,c in the output
    int b = idx / (T * C);
    int t = (idx / C) % T;
    int c = idx % C;
    // coords in input
    int C2 = C * 2;
    const floatX* inp1_ptr = inp + (b * T * C2 + t * C2 + c);
    const floatX* inp2_ptr = inp1_ptr + C;
    floatX* dinp1_ptr = dinp + (b * T * C2 + t * C2 + c);
    floatX* dinp2_ptr = dinp1_ptr + C;
    // backward
    x128 dinp1;
    x128 dinp2;
    x128 packed_dout = load128cs(dout_ptr);
    x128 packed_inp1 = load128cs(inp1_ptr); // fc1
    x128 packed_inp2 = load128cs(inp2_ptr); // fc2
    for(int k = 0; k < packed_inp1.size; ++k) {
        float x1 = (float)packed_inp1[k];
        float x2 = (float)packed_inp2[k];
        float dout = (float)packed_dout[k];

        float sx2 = 1.0f / (1.0f + expf(-x2)); // sigmoid of x2
        float dx1 = dout * x2 * sx2;
        float dx2 = dout * x1 * sx2 * (1.0f + x2 * (1.0f - sx2));

        dinp1[k] = (floatX)dx1;
        dinp2[k] = (floatX)dx2;
    }
    store128(dinp1_ptr, dinp1);
    store128(dinp2_ptr, dinp2);
}

// ----------------------------------------------------------------------------
// kernel launchers

@@ -84,3 +119,14 @@ void swiglu_forward_naive(floatX* out, const floatX* inp, int B, int T, int C, c
swiglu_forward_kernel2<<<grid_size, block_size, 0, stream>>>(out, inp, B, T, C);
cudaCheck(cudaGetLastError());
}

void swiglu_backward(floatX* dinp, const floatX* dout, const floatX* inp, int B, int T, int C, cudaStream_t stream) {
    // input is (B, T, 2C), output is (B, T, C)
    // we have that inp[b, t, :] = [fc1, fc2] (i.e. they are concatenated in each C-fiber)
    NVTX_RANGE_FN();
    const int block_size = 128;
    assert((B*T*C) % (block_size * x128::size) == 0);
    const int grid_size = CEIL_DIV(B*T*C, block_size * x128::size);
    swiglu_backward_kernel1<<<grid_size, block_size, 0, stream>>>(dinp, dout, inp, B, T, C);
    cudaCheck(cudaGetLastError());
}
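For reference, the math in swiglu_backward_kernel1 follows directly from the forward definition out = x1 * x2 * sigmoid(x2): d(out)/d(x1) = x2 * sigmoid(x2), and d(out)/d(x2) = x1 * (sigmoid(x2) + x2 * sigmoid(x2) * (1 - sigmoid(x2))) = x1 * sigmoid(x2) * (1 + x2 * (1 - sigmoid(x2))), which is exactly dx1 and dx2 in the kernel. A standalone PyTorch sketch (not part of the commit; shapes and names are illustrative) that checks these formulas against autograd:

import torch

def swiglu_backward_ref(dout, x1, x2):
    # analytic gradients of out = x1 * x2 * sigmoid(x2), mirroring the kernel math
    sx2 = torch.sigmoid(x2)
    dx1 = dout * x2 * sx2
    dx2 = dout * x1 * sx2 * (1.0 + x2 * (1.0 - sx2))
    return dx1, dx2

x1 = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
x2 = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
dout = torch.randn(8, 16, dtype=torch.float64)
out = x1 * x2 * torch.sigmoid(x2)  # same formula as swiglu_forward_kernel2
out.backward(dout)
dx1, dx2 = swiglu_backward_ref(dout, x1.detach(), x2.detach())
print((dx1 - x1.grad).abs().max().item(), (dx2 - x2.grad).abs().max().item())  # both ~1e-16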
92 changes: 53 additions & 39 deletions train_llama3.cu
@@ -68,7 +68,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/repkv.cuh"
// defines: precompute_freqs_cis, rope_forward
#include "llmc/rope.cuh"
// defines: swiglu_forward
// defines: swiglu_forward, swiglu_backward
#include "llmc/swiglu.cuh"
// ----------- Multi-GPU support -----------
// defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo
@@ -289,8 +289,8 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T);
tensors[16] = TENSOR_SPEC(data->losses, B * T);
tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * qkv_channels);
tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(NH*T, Vp)));
tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C);
tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp))));
tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels);
tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C);
}

@@ -786,6 +786,16 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
const size_t L = model->config.num_layers;
const size_t NH = model->config.num_heads;
const size_t C = model->config.channels;
const size_t n_head = model->config.num_heads;
const size_t n_kv_head = model->config.num_kv_heads;
const size_t hd = C / n_head; // head dimension
const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels
size_t hidden_dim = 4 * C;
hidden_dim = (2 * hidden_dim) / 3;
hidden_dim = model->config.ffn_dim_multiplier * hidden_dim;
hidden_dim = model->config.multiple_of * ((hidden_dim + model->config.multiple_of - 1) / model->config.multiple_of);
size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
size_t ffn_channels_post_gelu = hidden_dim; // swiglu halves the channels

ParameterTensors params = model->params; // for brevity
ParameterTensors grads = model->grads;
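The qkv_channels and hidden_dim arithmetic above mirrors the forward-pass sizing: grouped-query attention shrinks the QKV projection from 3*C to (n_head + 2*n_kv_head)*hd channels, and the SwiGLU MLP uses the Llama-style rounded hidden size with two concatenated input projections. A small Python sketch of the same arithmetic with assumed Llama-3-8B-style config values (C = 4096, 32 query heads, 8 KV heads, ffn_dim_multiplier = 1.3, multiple_of = 1024; illustrative, not taken from this commit):

C, n_head, n_kv_head = 4096, 32, 8
ffn_dim_multiplier, multiple_of = 1.3, 1024

hd = C // n_head                                # head dimension: 128
qkv_channels = (n_head + 2 * n_kv_head) * hd    # 6144 (vs 3*C = 12288 without GQA)

hidden_dim = 4 * C                              # 16384
hidden_dim = (2 * hidden_dim) // 3              # 10922
hidden_dim = int(ffn_dim_multiplier * hidden_dim)                            # 14198
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)   # 14336

ffn_channels = hidden_dim * 2                   # 28672: c_fc and c_fc2 concatenated
ffn_channels_post_gelu = hidden_dim             # 14336: SwiGLU halves the channels
print(qkv_channels, ffn_channels, ffn_channels_post_gelu)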
@@ -817,26 +827,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// backward the final layernorm
floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream);

// ------------------------------------------------------------------------
// DEBUGGING: we only work until this point right now, so exit here
// transfer the first 32 elements to CPU and print them
float* output = (float*)dresidual;
floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
for (int i = 0; i < 32; i++) {
printf("q[%d] = %.8f\n", i, (float) cpu[i]);
}
// write to .bin file
// move output to cpu
floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX));
cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost));
FILE* f = fopen("out.bin", "wb");
fwrite(cpu_output, sizeof(floatX), B*T*C, f);
fclose(f);
exit(0);
// ------------------------------------------------------------------------

// from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
// scratch for backward computations
floatX* dl_btc = residual;
@@ -850,37 +840,36 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// get the pointers of the weights for this layer
floatX* l_ln1w = params.ln1w + l * C;
floatX* l_ln1b = params.ln1b + l * C;
floatX* l_qkvw = params.qkvw + l * 3*C * C;
floatX* l_qkvw = params.qkvw + l * qkv_channels * C;
floatX* l_attprojw = params.attprojw + l * C * C;
floatX* l_ln2w = params.ln2w + l * C;
floatX* l_ln2b = params.ln2b + l * C;
floatX* l_fcw = params.fcw + l * 4*C * C;
floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
floatX* l_fcw = params.fcw + l * ffn_channels * C;
floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu;
// get the pointers of the gradients of the weights for this layer
floatX* dl_ln1w = grads.ln1w + l * C;
floatX* dl_ln1b = grads.ln1b + l * C;
floatX* dl_qkvw = grads.qkvw + l * 3*C * C;
floatX* dl_qkvb = grads.qkvb + l * 3*C;
floatX* dl_qkvw = grads.qkvw + l * qkv_channels * C;
floatX* dl_qkvb = grads.qkvb + l * qkv_channels;
floatX* dl_attprojw = grads.attprojw + l * C * C;
floatX* dl_attprojb = grads.attprojb + l * C;
floatX* dl_ln2w = grads.ln2w + l * C;
floatX* dl_ln2b = grads.ln2b + l * C;
floatX* dl_fcw = grads.fcw + l * 4*C * C;
floatX* dl_fcb = grads.fcb + l * 4*C;
floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
floatX* dl_fcw = grads.fcw + l * ffn_channels * C;
floatX* dl_fcb = grads.fcb + l * ffn_channels;
floatX* dl_fcprojw = grads.fcprojw + l * C * ffn_channels_post_gelu;
floatX* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
float* l_ln1_mean = acts.ln1_mean + l * B * T;
float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
float* l_ln2_mean = acts.ln2_mean + l * B * T;
float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C;
floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels;
floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu;
// get the pointers of the gradients of the activations for this layer
// notice that there is no l *, because we just have a single copy, and keep
// re-using this memory in every Transformer block as we calculate backward pass
@@ -891,14 +880,39 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
if(model->recompute >= 1) {
// recompute >= 1 means we recompute gelu. in this case,
// l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here
gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream);
// gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream);
swiglu_forward(l_fch_gelu, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream);
}
matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion);
// backward the 2nd matmul of MLP
matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream);
// backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace
swiglu_backward(scratchX, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream);
// backward the 1st matmul of MLP
if(model->recompute >= 2) {
// same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream);
rmsnorm_forward(l_ln2, l_ln2_rstd, l_residual2, l_ln2w, B, T, C, main_stream);
}
matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream);
matmul_backward(dl_btc, dl_fcw, dl_fcb, scratchX, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream);

// ------------------------------------------------------------------------
// DEBUGGING: we only work until this point right now, so exit here
// transfer the first 32 elements to CPU and print them
float* output = (float*)dl_btc;
floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
for (int i = 0; i < 32; i++) {
printf("q[%d] = %.8f\n", i, (float) cpu[i]);
}
// write to .bin file
// move output to cpu
floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX));
cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost));
FILE* f = fopen("out.bin", "wb");
fwrite(cpu_output, sizeof(floatX), B*T*C, f);
fclose(f);
exit(0);
// ------------------------------------------------------------------------

// layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream);
matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream);
13 changes: 7 additions & 6 deletions train_llama3.py
@@ -234,7 +234,11 @@ def __init__(self, config):

def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
    x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
    x = x + self.mlp(self.ln_2(x))
    MLP_INPUT = self.ln_2(x)
    MLP_INPUT = MLP_INPUT.detach()
    MLP_INPUT.requires_grad = True
    self.MLP_INPUT = MLP_INPUT
    x = x + self.mlp(MLP_INPUT)
    return x

# -----------------------------------------------------------------------------
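The detach/requires_grad pair in Block.forward above is the standard trick for probing the gradient at an internal activation: detaching cuts MLP_INPUT out of the existing graph and makes it a leaf, so after backward() its .grad holds the gradient flowing into the MLP, which is what the dl_btc dump in train_llama3.cu above is meant to line up with. A minimal standalone sketch of the same pattern on a toy module (not part of the commit; names are illustrative):

import torch
import torch.nn as nn

layer1, layer2 = nn.Linear(4, 4), nn.Linear(4, 4)
x = torch.randn(2, 4)

h = layer1(x)
h_probe = h.detach()          # cut from the graph: layer1 receives no gradient
h_probe.requires_grad = True  # make it a leaf so autograd records .grad for it
y = layer2(h_probe).sum()
y.backward()
print(h_probe.grad.shape)     # gradient of the loss w.r.t. the probed activation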
@@ -301,10 +305,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):

for i, block in enumerate(self.transformer.h):
    x = block(x, freqs_cis, start_pos, mask)

self.DEBUG_INPUT = x.detach()
self.DEBUG_INPUT.requires_grad = True
x = self.transformer.ln_f(self.DEBUG_INPUT)
x = self.transformer.ln_f(x)

if targets is not None:
# if we are given some desired targets also calculate the loss
@@ -1259,7 +1260,7 @@ def get_lr(it):

# ---------------------------------------------------------------------
# DEBUGGING: print first 32 elements of x
x = model.DEBUG_INPUT.grad
x = model.transformer.h[-1].MLP_INPUT.grad
for i in range(32):
    print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item()))
# write to .bin file
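To compare the two debug dumps, a rough sketch (not part of the commit) that loads both binaries with NumPy; the reference filename and the float32 dtype are assumptions — the Python-side write is collapsed in the diff above, and floatX is float32 only in an fp32 build:

import numpy as np

cuda_grad = np.fromfile("out.bin", dtype=np.float32)       # written by train_llama3.cu (assuming floatX == float)
ref_grad = np.fromfile("grad_ref.bin", dtype=np.float32)   # hypothetical name for the PyTorch-side dump

n = min(cuda_grad.size, ref_grad.size)
diff = np.abs(cuda_grad[:n] - ref_grad[:n])
print("max abs diff:", diff.max())
print("first 8 (cuda):", cuda_grad[:8])
print("first 8 (ref): ", ref_grad[:8])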
