From c746e06f49e3000039da79fe5a553317fee7e986 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Tue, 1 Oct 2024 17:19:39 +0000
Subject: [PATCH] take out debugging stuff. we can now run training loop for
 both models. they don't match yet

---
 llmc/encoder.cuh | 13 ++++++++-----
 train_llama3.cu  | 24 ++----------------------
 train_llama3.py  | 17 -----------------
 3 files changed, 10 insertions(+), 44 deletions(-)

diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh
index fbaf56af1..6ab94e3d4 100644
--- a/llmc/encoder.cuh
+++ b/llmc/encoder.cuh
@@ -197,11 +197,14 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output
     NVTX_RANGE_FN();
 
     // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
-    const int block_size = 256;
-    const int N = T * C / x128::size;
-    const int grid_size = CEIL_DIV(N, block_size);
-    wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
-    cudaCheck(cudaGetLastError());
+    // GPT-2 has wpe (absolute positional encoding), but Llama 3 does not as it uses RoPE
+    if (dwpe != NULL) {
+        const int block_size = 256;
+        const int N = T * C / x128::size;
+        const int grid_size = CEIL_DIV(N, block_size);
+        wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
+        cudaCheck(cudaGetLastError());
+    }
 
     // check the GPU scratch buffer is large enough to hold the bucket info and workload indices
     // todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
diff --git a/train_llama3.cu b/train_llama3.cu
index 327802c1d..ccb76cf52 100644
--- a/train_llama3.cu
+++ b/train_llama3.cu
@@ -943,27 +943,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         }
     }
 
-    // ------------------------------------------------------------------------
-    // DEBUGGING: we only work until this point right now, so exit here
-    // transfer the first 32 elements to CPU and print them
-    float* output = (float*)dresidual;
-    floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
-    cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
-    for (int i = 0; i < 32; i++) {
-        printf("q[%d] = %.8f\n", i, (float) cpu[i]);
-    }
-    // write to .bin file
-    // move output to cpu
-    int sz = B*T*C;
-    floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX));
-    cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost));
-    FILE* f = fopen("out.bin", "wb");
-    fwrite(cpu_output, sizeof(floatX), sz, f);
-    fclose(f);
-    exit(0);
-    // ------------------------------------------------------------------------
-
-    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
+    encoder_backward(grads.wte, NULL, scratchX, model->workload_indices, model->bucket_info,
                      dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);
 
     // Aggregate all gradients that are not part of the transformer blocks
@@ -977,7 +957,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
         // reduce the gradients for non-transformer block parameters
         floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
-        const size_t nelem[] = {Vp * C, T * C, C, C};
+        const size_t nelem[] = {Vp * C, Vp * C, C, C};
         multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
     }
 
diff --git a/train_llama3.py b/train_llama3.py
index 6cf9c6f69..9a4ee24b3 100644
--- a/train_llama3.py
+++ b/train_llama3.py
@@ -299,11 +299,6 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):
         freqs_cis = self.freqs_cis[start_pos:start_pos+t]
         mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)
 
-        DEBUG_POINT = x.detach()
-        DEBUG_POINT = DEBUG_POINT.requires_grad_(True)
-        self.DEBUG_POINT = DEBUG_POINT
-        x = DEBUG_POINT
-
         for i, block in enumerate(self.transformer.h):
             x = block(x, freqs_cis, start_pos, mask)
         x = self.transformer.ln_f(x)
@@ -1258,18 +1253,6 @@ def get_lr(it):
         # backward pass
         if not args.inference_only:
             loss.backward()
-
-            # ---------------------------------------------------------------------
-            # DEBUGGING: print first 32 elements of x
-            x = model.DEBUG_POINT.grad
-            for i in range(32):
-                print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item()))
-            # write to .bin file
-            with open("ref.bin", "wb") as f:
-                f.write(x.view(-1).cpu().detach().numpy().tobytes())
-            breakpoint()
-            # ---------------------------------------------------------------------
-
         if ddp:
             dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
             lossf = lossf.item()
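
Note on the encoder_backward change: Llama 3 carries position information
through RoPE applied to the query/key vectors inside attention rather than
through a learned wpe table, so there is no positional-embedding gradient to
compute and dwpe is passed as NULL. Below is a minimal PyTorch sketch of that
mechanism; it mirrors the freqs_cis buffer visible in train_llama3.py, but the
function names and shapes here are illustrative, not the repo's exact code.

    import torch

    def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0):
        # complex rotation factors, one per (position, channel-pair); this is a
        # fixed buffer, not a trainable parameter -- hence no gradient for it
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(seq_len).float(), freqs)
        return torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim//2)

    def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # x: (B, T, n_head, head_dim); rotate consecutive channel pairs of q/k
        # by position-dependent angles
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_c = x_c * freqs_cis[None, :, None, :]  # broadcast over batch and heads
        return torch.view_as_real(x_c).flatten(3).type_as(x)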
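
On "they don't match yet": the debug code removed here dumped the gradient at
the same point of both implementations (out.bin from train_llama3.cu, ref.bin
from train_llama3.py), each holding B*T*C raw floatX values. A throwaway
script along these lines can quantify the mismatch. It assumes an fp32 build
so that floatX is float (a bf16 dump would need a different dtype and a
conversion step), and compare_grads.py is a hypothetical helper, not a file in
the repo.

    # compare_grads.py (hypothetical, not in the repo): diff the two gradient
    # dumps written by the debug code this patch removes.
    import numpy as np

    out = np.fromfile("out.bin", dtype=np.float32)  # from train_llama3.cu
    ref = np.fromfile("ref.bin", dtype=np.float32)  # from train_llama3.py
    assert out.size == ref.size, "size mismatch: check B, T, C and floatX dtype"

    diff = np.abs(out - ref)
    print(f"max abs diff:  {diff.max():.8f}")
    print(f"mean abs diff: {diff.mean():.8f}")
    # mirror the removed printf/print loops: first elements side by side
    for i in range(32):
        print(f"q[{i}] = {out[i]:.8f} vs {ref[i]:.8f}")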