Skip to content

Commit 09b47a7

Browse files
committed
llama3 starting point is at gpt-2 exact copy paste for both train/test files
1 parent bd8c604 commit 09b47a7

File tree

2 files changed

+2299
-0
lines changed

2 files changed

+2299
-0
lines changed

test_llama3.cu

+395
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
#define TESTING
2+
#include "train_gpt2.cu"
3+
4+
// Poor man's tensor checker: compares a calculated tensor against a reference, element-wise.
//
//   a          calculated tensor (n floats)
//   b          reference tensor (n floats)
//   n          number of elements to compare
//   label      name printed in the report
//   threshold  absolute tolerance; the effective per-element tolerance is
//              threshold + |b[i]| * epsilon
//
// Prints the first few element pairs and a one-line summary to stdout.
// Returns 1 if every element is within tolerance, 0 otherwise. An element whose
// difference is NaN is treated as out of tolerance.
int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {
    int print_upto = 10;
    int ok = 1;
    float max_diff = 0.0f;
    float max_rel_error = 0.0f;
    float max_to_threshold = 0.f;
    float max_a = 0.0f;
    float max_b = 0.0f;
    // relative part of the tolerance (originally commented as "BF16 epsilon value")
    // NOTE(review): bf16 machine epsilon is 2^-7 ~= 0.0078; confirm 0.079 is intended
    float epsilon = 0.079;
    printf("---\n");
    printf("checking tensor: %s\n", label);
    for (int i = 0; i < n; i++) {
        float t_eff = threshold + fabsf(b[i]) * epsilon;
        float diff = fabsf(a[i] - b[i]);
        // One predicate drives both the printed status and the verdict. It is phrased
        // as (diff <= t_eff) so that a NaN diff evaluates to "not within tolerance".
        int within = (diff <= t_eff);
        max_to_threshold = fmaxf(max_to_threshold, diff / t_eff);
        if (diff > max_diff) {
            // remember the worst element so the summary can show it
            max_diff = diff;
            float denom = fabsf(b[i]);
            max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom;
            max_a = a[i];
            max_b = b[i];
        }
        if (!within) {
            ok = 0;
        }
        // print the first few elements so we can visually assess the "proof" of the comparison
        if (i < print_upto) {
            printf("%s", within ? "OK " : "NOT OK ");
            printf("%f %f\n", a[i], b[i]);
        }
    }
    // print the final result
    printf("%s, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\n",
           ok ? "TENSOR OK" : "TENSOR NOT OK",
           max_diff, max_rel_error, max_a, max_b, max_to_threshold*100);
    return ok;
}
47+
48+
// Reference copy of the train file's parameter tensors, held entirely in float.
// Every field is a plain float pointer; the trailing comment gives the tensor shape.
// Field order is significant: the fields are filled in declaration order from one buffer.
typedef struct {
    float *wte;       // (Vp, C)
    float *wpe;       // (maxT, C)
    float *ln1w;      // (L, C)
    float *ln1b;      // (L, C)
    float *qkvw;      // (L, 3*C, C)
    float *qkvb;      // (L, 3*C)
    float *attprojw;  // (L, C, C)
    float *attprojb;  // (L, C)
    float *ln2w;      // (L, C)
    float *ln2b;      // (L, C)
    float *fcw;       // (L, 4*C, C)
    float *fcb;       // (L, 4*C)
    float *fcprojw;   // (L, C, 4*C)
    float *fcprojb;   // (L, C)
    float *lnfw;      // (C)
    float *lnfb;      // (C)
} FloatParameterTensors;
67+
// compile-time guard: the struct must consist of exactly NUM_PARAMETER_TENSORS pointers,
// because its fields are walked as a list of that many pointers when they are assigned
static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");
68+
69+
// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU
70+
float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) {
71+
// calculate the total number of parameters
72+
size_t num_parameters = 0;
73+
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
74+
num_parameters += param_sizes[i];
75+
}
76+
// everything is float so number of bytes to allocate is a simple multiplication
77+
float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float));
78+
float** ptrs[] = {
79+
&params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
80+
&params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
81+
&params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
82+
};
83+
float* params_memory_iterator = params_memory;
84+
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
85+
*(ptrs[i]) = params_memory_iterator;
86+
params_memory_iterator += param_sizes[i];
87+
}
88+
return params_memory;
89+
}
90+
91+
int main(int argc, char *argv[]) {
92+
char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi"
93+
int num_processes = -1; // doesn't matter when using MPI
94+
int process_rank = -1; // doesn't matter when using MPI
95+
int gpus_per_node = -1; // doesn't matter when using MPI
96+
char server_ip[256] = ""; // doesn't matter when using MPI
97+
char fs_path[256] = ""; // doesn't matter when using MPI
98+
multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
99+
common_start(false, true);
100+
101+
// set the right paths
102+
#if defined(ENABLE_BF16)
103+
const char* load_filename = "gpt2_124M_bf16.bin";
104+
#else
105+
const char* load_filename = "gpt2_124M.bin";
106+
#endif
107+
108+
// build the GPT-2 model from a checkpoint
109+
GPT2 model;
110+
gpt2_init_common(&model);
111+
112+
gpt2_build_from_checkpoint(&model, load_filename);
113+
size_t V = model.config.vocab_size;
114+
size_t Vp = model.config.padded_vocab_size;
115+
size_t maxT = model.config.max_seq_len;
116+
size_t L = model.config.num_layers;
117+
size_t C = model.config.channels;
118+
119+
for (int i = 1; i < argc; i+=2) {
120+
if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag
121+
if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { exit(EXIT_FAILURE); } // must be -x[y] (one dash, one or two letters)
122+
if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash
123+
if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }
124+
else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); }
125+
else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); }
126+
}
127+
128+
// load additional information that we will use for debugging and error checking
129+
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
130+
int state_header[256];
131+
freadCheck(state_header, sizeof(int), 256, state_file);
132+
if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); }
133+
if (state_header[1] != 2) {
134+
fprintf(stderr, "Bad version in state file\n");
135+
fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
136+
exit(EXIT_FAILURE);
137+
}
138+
int B = state_header[2]; // batch size, e.g. 4
139+
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
140+
assert(0 <= T && T <= maxT);
141+
printf("[State]\n");
142+
printf("batch_size: %d\n", B);
143+
printf("seq_len: %d\n", T);
144+
145+
set_zero_configs(&multi_gpu_config, 0, model.num_parameters);
146+
147+
// read reference information from the file saved from Python/PyTorch side
148+
// 1) input x and y
149+
int* x = (int*)mallocCheck(B * T * sizeof(int));
150+
int* y = (int*)mallocCheck(B * T * sizeof(int));
151+
freadCheck(x, sizeof(int), B*T, state_file);
152+
freadCheck(y, sizeof(int), B*T, state_file);
153+
// 2) results of forward pass (logits and loss)
154+
float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));
155+
float* expected_loss = (float*) mallocCheck(1 * sizeof(float));
156+
freadCheck(expected_logits, sizeof(float), B*T*V, state_file);
157+
freadCheck(expected_loss, sizeof(float), 1, state_file);
158+
// 3) results of backward pass (parameter gradients)
159+
FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32
160+
float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements);
161+
freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);
162+
fcloseCheck(state_file);
163+
164+
// this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads
165+
void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes);
166+
float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float));
167+
168+
// overall OK signal for the test
169+
int allok = 1;
170+
171+
gpt2_allocate_state(&model, B, T);
172+
173+
// First, do target-free forward pass to validate logits
174+
gpt2_forward(&model, x, B, T);
175+
// at this point, target should be equal to expected_logits, let's compare
176+
// copy logits to CPU so we can compare them
177+
floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));
178+
float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));
179+
cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost));
180+
for (int i = 0; i < B * T * Vp; i++) {
181+
logits_cpu[i] = (float)logits_cpu_raw[i];
182+
}
183+
184+
float logit_accuracy_threshold = 1e-3f;
185+
float loss_diff_threshold = 1e-5f;
186+
// FP16 and lower require very high tolerances unfortunately. TODO look into more
187+
#if defined(ENABLE_BF16) || defined(ENABLE_F16)
188+
logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! :(
189+
loss_diff_threshold = 0.05f;
190+
#endif
191+
192+
// compare the output logits from the forward pass
193+
// also careful that we don't access and compare the padded columns of logits
194+
int logits_ok = 1;
195+
float max_diff = 0.0f;
196+
for (int bt = 0; bt < B*T; bt++) {
197+
for (int v = 0; v < V; v++) {
198+
int i = bt * Vp + v; // linearized index
199+
if (i < 10) {
200+
printf("%f, %f\n", expected_logits[i], logits_cpu[i]);
201+
}
202+
float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]);
203+
max_diff = fmaxf(max_diff, diff);
204+
if (diff >= logit_accuracy_threshold) {
205+
printf("MISMATCH AT INDEX %d,%d: ", bt, v);
206+
printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]);
207+
logits_ok = 0;
208+
bt = B*T; // to break out of both loops
209+
break;
210+
}
211+
}
212+
}
213+
allok = allok && logits_ok;
214+
if(!logits_ok) { printf("NOT "); }
215+
printf("OK (LOGITS)\n");
216+
printf("logit max diff: %f\n", max_diff);
217+
218+
// let's do 10 training iterations, following the pytorch code
219+
float losses[10];
220+
for (int step = 0; step < 10; step++) {
221+
struct timespec start, end;
222+
clock_gettime(CLOCK_MONOTONIC, &start);
223+
gpt2_forward(&model, x, B, T);
224+
gpt2_backward_and_reduce(&model, x, y, 1, 0);
225+
clock_gettime(CLOCK_MONOTONIC, &end);
226+
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
227+
228+
if (step == 0) {
229+
// error checking at step 0 for reference activations
230+
231+
// move the (mixed precision) grads from GPU to CPU
232+
cudaCheck(cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost));
233+
234+
// convert all gradients to float on the CPU
235+
char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char*
236+
float* dst_iterator = (float*)grads_memory_cpu_float; // float*
237+
float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python
238+
float* tensors1[NUM_PARAMETER_TENSORS];
239+
float* tensors2[NUM_PARAMETER_TENSORS];
240+
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
241+
if (model.param_sizeof[i] == sizeof(float)) {
242+
// float tensor => copy over directly
243+
memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float));
244+
} else {
245+
// low-precision tensor => convert to float
246+
assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm
247+
for (size_t j = 0; j < model.param_elements[i]; j++) {
248+
dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float
249+
}
250+
}
251+
// for convenience record the position of comparison for reality vs. expectation
252+
tensors1[i] = dst_iterator; // reality
253+
tensors2[i] = exp_iterator; // expectation
254+
// advance the iterators
255+
src_iterator += model.param_elements[i] * model.param_sizeof[i];
256+
dst_iterator += model.param_elements[i];
257+
exp_iterator += model.param_elements[i];
258+
}
259+
260+
// compare the gradients on the parameters all at once, in fp32
261+
// I set the tolerances manually by inspecting the gradient differences for
262+
// a few elements of each tensor. bf16 looks ok but not amazing here.
263+
// It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure.
264+
// Also, if code changes and some of these get tripped, it could be ok if it's not by too much,
265+
// because our use of stochastic rounding is adding some non-determinism "pepper noise".
266+
// In that case it's ok to extend the tolerance by a bit, after a manual review.
267+
// Also, different GPUs may use different matrix multiplication algorithms, so the
268+
// actual errors can be hardware specific.
269+
270+
float grad_thresholds[NUM_PARAMETER_TENSORS] = {5e-1f, 4e-3f, 1e-1f, 3.5e-2f, 2e-2f, 3e-2f, 5e-2f, 5e-2f, 5e-2f, 1.5e-2f, 5e-4f, 8e-3f, 1.5e-3f, 2.5e-3f, 1e-1f, 2e-2f};
271+
#if defined(ENABLE_FP32)
272+
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
273+
grad_thresholds[i] = 1e-6f; // we can be much more precise in FP32
274+
}
275+
#endif
276+
277+
allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]);
278+
allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", grad_thresholds[1]);
279+
allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]);
280+
allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]);
281+
allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]);
282+
allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", grad_thresholds[5]);
283+
allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", grad_thresholds[6]);
284+
allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", grad_thresholds[7]);
285+
allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", grad_thresholds[8]);
286+
allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", grad_thresholds[9]);
287+
allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", grad_thresholds[10]);
288+
allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", grad_thresholds[11]);
289+
allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", grad_thresholds[12]);
290+
allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", grad_thresholds[13]);
291+
allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", grad_thresholds[14]);
292+
allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]);
293+
}
294+
295+
float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);
296+
float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f;
297+
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config);
298+
299+
// print the timing information at the end
300+
printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000);
301+
// the expected losses from PyTorch were copied over after the print formatting rounded
302+
// them to 6 decimal places, so we do the same here
303+
float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000;
304+
losses[step] = rounded_loss;
305+
}
306+
307+
// expected losses are as follows, from Python
308+
float expected_losses[10] = {
309+
5.270009f,
310+
4.060681f,
311+
3.320085f,
312+
2.717550f,
313+
2.181066f,
314+
1.653923f,
315+
1.168050f,
316+
0.736873f,
317+
0.401021f,
318+
0.187493f
319+
};
320+
321+
// compare
322+
for (int i = 0; i < 10; i++) {
323+
if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {
324+
printf("LOSS MISMATCH AT STEP %d: %f %f\n", i+1, losses[i], expected_losses[i]);
325+
allok = 0;
326+
} else {
327+
printf("loss ok at step %d: %f %f\n", i+1, losses[i], expected_losses[i]);
328+
}
329+
}
330+
331+
// Finally, let's check determinism
332+
gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt");
333+
334+
DataLoader loader;
335+
dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_val.bin", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1);
336+
save_state("test_gpt2cu_state.ckpt", 10, &model, &loader);
337+
int tokens[10];
338+
for (int step = 0; step < 10; step++) {
339+
dataloader_next_batch(&loader);
340+
gpt2_forward(&model, loader.inputs, B, T);
341+
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
342+
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
343+
losses[step] = model.mean_loss;
344+
tokens[step] = loader.inputs[0];
345+
}
346+
347+
// reload
348+
gpt2_free(&model);
349+
gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt");
350+
int ld_step;
351+
gpt2_allocate_state(&model, B, T);
352+
load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt");
353+
for (int step = 0; step < 10; step++) {
354+
dataloader_next_batch(&loader);
355+
gpt2_forward(&model, loader.inputs, B, T);
356+
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
357+
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
358+
359+
if(loader.inputs[0] != tokens[step]) {
360+
printf("Nondeterminism! Token mismatch at step %d: %d vs %d\n", step, tokens[step], loader.inputs[0]);
361+
allok = false;
362+
break;
363+
}
364+
365+
if(losses[step] != model.mean_loss) {
366+
printf("Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\n", step, losses[step], model.mean_loss);
367+
allok = false;
368+
break;
369+
} else {
370+
printf("loss ok at step %d: %f %f\n", step, losses[step], model.mean_loss);
371+
}
372+
}
373+
374+
// final approval
375+
printf("overall okay: %d\n", allok);
376+
377+
// delete intermediate test files
378+
remove("test_gpt2cu_model.ckpt");
379+
remove("test_gpt2cu_state.ckpt");
380+
381+
// free everything
382+
dataloader_free(&loader);
383+
gpt2_free(&model);
384+
common_free(model);
385+
free(x);
386+
free(y);
387+
free(logits_cpu_raw);
388+
free(logits_cpu);
389+
free(expected_logits);
390+
free(expected_loss);
391+
free(expected_grads_memory);
392+
free(grads_memory_cpu);
393+
free(grads_memory_cpu_float);
394+
return allok ? EXIT_SUCCESS : EXIT_FAILURE;
395+
}

0 commit comments

Comments
 (0)