Commit c746e06

take out debugging stuff. we can now run training loop for both models. they don't match yet

karpathy committed Oct 1, 2024
1 parent 9099a0a commit c746e06
Showing 3 changed files with 10 additions and 44 deletions.
13 changes: 8 additions & 5 deletions llmc/encoder.cuh
@@ -197,11 +197,14 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output
     NVTX_RANGE_FN();

     // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
-    const int block_size = 256;
-    const int N = T * C / x128::size;
-    const int grid_size = CEIL_DIV(N, block_size);
-    wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
-    cudaCheck(cudaGetLastError());
+    // GPT-2 has wpe (absolute positional encoding), but Llama 3 does not as it uses RoPE
+    if (dwpe != NULL) {
+        const int block_size = 256;
+        const int N = T * C / x128::size;
+        const int grid_size = CEIL_DIV(N, block_size);
+        wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
+        cudaCheck(cudaGetLastError());
+    }

     // check the GPU scratch buffer is large enough to hold the bucket info and workload indices
     // todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
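The new guard reflects the difference in positional encoding: GPT-2 adds a learned absolute positional table (wpe) to the token embedding, so a wpe gradient exists, while Llama 3 injects position by rotating query/key vectors with fixed RoPE frequencies, which are not trained, so dwpe is NULL. A minimal PyTorch sketch of the two schemes, for illustration only (shapes, variable names, and the rope base of 500000 are assumptions, not the repo's code):

import torch

B, T, C = 2, 8, 16
wte = torch.nn.Embedding(100, C)              # token embedding (both models have this)
idx = torch.randint(0, 100, (B, T))

# GPT-2 style: a learned absolute positional table; wpe.weight would receive a gradient
wpe = torch.nn.Embedding(T, C)
x_gpt2 = wte(idx) + wpe(torch.arange(T))

# Llama 3 style: positions enter through fixed complex rotations (RoPE); nothing to learn,
# hence no dwpe in the backward pass
theta = 500000.0                               # rope base, assumed here
freqs = 1.0 / (theta ** (torch.arange(0, C, 2).float() / C))
angles = torch.outer(torch.arange(T).float(), freqs)
freqs_cis = torch.polar(torch.ones_like(angles), angles)    # (T, C/2) complex, constant
q = wte(idx)                                   # stand-in for a query vector of size C
q_c = torch.view_as_complex(q.reshape(B, T, -1, 2))
q_rot = torch.view_as_real(q_c * freqs_cis).flatten(2)      # position-aware, same shape as q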
24 changes: 2 additions & 22 deletions train_llama3.cu
@@ -943,27 +943,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         }
     }

-    // ------------------------------------------------------------------------
-    // DEBUGGING: we only work until this point right now, so exit here
-    // transfer the first 32 elements to CPU and print them
-    float* output = (float*)dresidual;
-    floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
-    cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
-    for (int i = 0; i < 32; i++) {
-        printf("q[%d] = %.8f\n", i, (float) cpu[i]);
-    }
-    // write to .bin file
-    // move output to cpu
-    int sz = B*T*C;
-    floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX));
-    cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost));
-    FILE* f = fopen("out.bin", "wb");
-    fwrite(cpu_output, sizeof(floatX), sz, f);
-    fclose(f);
-    exit(0);
-    // ------------------------------------------------------------------------
-
-    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
+    encoder_backward(grads.wte, NULL, scratchX, model->workload_indices, model->bucket_info,
                      dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);

     // Aggregate all gradients that are not part of the transformer blocks
@@ -977,7 +957,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
     cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
     // reduce the gradients for non-transformer block parameters
     floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
-    const size_t nelem[] = {Vp * C, T * C, C, C};
+    const size_t nelem[] = {Vp * C, Vp * C, C, C};
     multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
 }

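The debugging block deleted from gpt2_backward_and_reduce was one half of a cross-checking workflow: the CUDA side dumped the gradient flowing into the encoder (dresidual) to out.bin, while the matching block removed from train_llama3.py below dumped the PyTorch reference to ref.bin; per the commit message, the two don't match yet. A hypothetical comparison script, not part of the repo, assuming both files were dumped as flat float32 arrays of equal length:

import numpy as np

out = np.fromfile("out.bin", dtype=np.float32)    # CUDA dump (dtype assumed)
ref = np.fromfile("ref.bin", dtype=np.float32)    # PyTorch dump (dtype assumed)
assert out.size == ref.size, f"size mismatch: {out.size} vs {ref.size}"

abs_err = np.abs(out - ref)
rel_err = abs_err / (np.abs(ref) + 1e-8)
print(f"max abs err {abs_err.max():.3e}, max rel err {rel_err.max():.3e}")
for i in range(8):                                # mirror the q[i] prints from the deleted code
    print(f"q[{i}]  cuda={out[i]:+.8f}  ref={ref[i]:+.8f}")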
17 changes: 0 additions & 17 deletions train_llama3.py
@@ -299,11 +299,6 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):
         freqs_cis = self.freqs_cis[start_pos:start_pos+t]
         mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)

-        DEBUG_POINT = x.detach()
-        DEBUG_POINT = DEBUG_POINT.requires_grad_(True)
-        self.DEBUG_POINT = DEBUG_POINT
-        x = DEBUG_POINT
-
         for i, block in enumerate(self.transformer.h):
             x = block(x, freqs_cis, start_pos, mask)
         x = self.transformer.ln_f(x)
@@ -1258,18 +1253,6 @@ def get_lr(it):
     # backward pass
     if not args.inference_only:
         loss.backward()
-
-        # ---------------------------------------------------------------------
-        # DEBUGGING: print first 32 elements of x
-        x = model.DEBUG_POINT.grad
-        for i in range(32):
-            print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item()))
-        # write to .bin file
-        with open("ref.bin", "wb") as f:
-            f.write(x.view(-1).cpu().detach().numpy().tobytes())
-        breakpoint()
-        # ---------------------------------------------------------------------
-
     if ddp:
         dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
     lossf = lossf.item()
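The removed Python code relied on a gradient-probing trick: detach the activation at the point of interest, turn it back into a leaf that requires grad, and route the forward pass through it, so that after backward its .grad holds exactly the gradient arriving at that point. A minimal standalone sketch of the idea (names and shapes are illustrative, not the repo's):

import torch

x = torch.randn(2, 8, 16)                   # stand-in for activations entering the blocks
probe = x.detach().requires_grad_(True)     # new leaf tensor at the probe point
y = (probe * 3.0).sum()                     # stand-in for the rest of the forward pass + loss
y.backward()
print(probe.grad.view(-1)[:4])              # d(loss)/d(activation) at the probe point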
