|
| 1 | +#define TESTING |
| 2 | +#include "train_gpt2.cu" |
| 3 | + |
// Poor man's tensor checker: element-wise comparison of a computed tensor
// against a reference tensor.
//
//   a         - the calculated tensor (n floats)
//   b         - the reference tensor (n floats)
//   n         - number of elements to compare
//   label     - name used in the printed report
//   threshold - absolute tolerance; every element additionally gets a relative
//               allowance of |b[i]| * epsilon added on top of it
//
// Prints the first few element pairs plus a one-line summary and returns 1 if
// every element is within tolerance, 0 otherwise. An element whose difference
// is NaN (e.g. the calculated value is NaN) counts as a failure.
int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {
    const int print_upto = 10;
    int ok = 1;
    int num_nan = 0;                // elements whose difference is NaN
    float max_diff = 0.0f;
    float max_rel_error = 0.0f;
    float max_to_threshold = 0.f;
    float max_a = 0.0f;
    float max_b = 0.0f;
    // Relative tolerance factor applied to |b[i]|.
    // NOTE(review): originally labelled "BF16 epsilon value", but bf16 machine
    // epsilon is 2^-7 (~0.0078); the value is kept as-is so existing tolerances
    // do not change - confirm what was intended.
    const float epsilon = 0.079f;
    printf("---\n");
    printf("checking tensor: %s\n", label);
    for (int i = 0; i < n; i++) {
        float t_eff = threshold + fabsf(b[i]) * epsilon;
        float diff = fabsf(a[i] - b[i]);
        // written as a negated "<=" so that a NaN diff is treated as out of tolerance;
        // a plain "diff > t_eff" is false for NaN and would silently pass
        int within = (diff <= t_eff);
        if (diff != diff) { num_nan++; }
        max_to_threshold = fmaxf(max_to_threshold, diff / t_eff);
        if (diff > max_diff) {
            // remember the worst element so the summary can show it
            max_diff = diff;
            float denom = fabsf(b[i]);
            max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom;
            max_a = a[i];
            max_b = b[i];
        }
        if (!within) {
            ok = 0;
        }
        // print the first few elements so we can visually assess the "proof" of the comparison
        if (i < print_upto) {
            printf("%s", within ? "OK " : "NOT OK ");
            printf("%f %f\n", a[i], b[i]);
        }
    }
    if (num_nan > 0) {
        // NaN differences never show up in max_diff, so report them explicitly
        printf("WARNING: %d element(s) compared as NaN\n", num_nan);
    }
    // print the final result
    if (ok) {
        printf("TENSOR OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\n",
               max_diff, max_rel_error, max_a, max_b, max_to_threshold*100);
    } else {
        printf("TENSOR NOT OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\n",
               max_diff, max_rel_error, max_a, max_b, max_to_threshold*100);
    }
    return ok;
}
| 47 | + |
| 48 | +// the same tensors as in the train file, but in float, which are used as reference |
| 49 | +typedef struct { |
| 50 | + float* wte; // (Vp, C) |
| 51 | + float* wpe; // (maxT, C) |
| 52 | + float* ln1w; // (L, C) |
| 53 | + float* ln1b; // (L, C) |
| 54 | + float* qkvw; // (L, 3*C, C) |
| 55 | + float* qkvb; // (L, 3*C) |
| 56 | + float* attprojw; // (L, C, C) |
| 57 | + float* attprojb; // (L, C) |
| 58 | + float* ln2w; // (L, C) |
| 59 | + float* ln2b; // (L, C) |
| 60 | + float* fcw; // (L, 4*C, C) |
| 61 | + float* fcb; // (L, 4*C) |
| 62 | + float* fcprojw; // (L, C, 4*C) |
| 63 | + float* fcprojb; // (L, C) |
| 64 | + float* lnfw; // (C) |
| 65 | + float* lnfb; // (C) |
| 66 | +} FloatParameterTensors; |
| 67 | +static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); |
| 68 | + |
| 69 | +// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU |
| 70 | +float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) { |
| 71 | + // calculate the total number of parameters |
| 72 | + size_t num_parameters = 0; |
| 73 | + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { |
| 74 | + num_parameters += param_sizes[i]; |
| 75 | + } |
| 76 | + // everything is float so number of bytes to allocate is a simple multiplication |
| 77 | + float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); |
| 78 | + float** ptrs[] = { |
| 79 | + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, |
| 80 | + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, |
| 81 | + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb |
| 82 | + }; |
| 83 | + float* params_memory_iterator = params_memory; |
| 84 | + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { |
| 85 | + *(ptrs[i]) = params_memory_iterator; |
| 86 | + params_memory_iterator += param_sizes[i]; |
| 87 | + } |
| 88 | + return params_memory; |
| 89 | +} |
| 90 | + |
| 91 | +int main(int argc, char *argv[]) { |
| 92 | + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" |
| 93 | + int num_processes = -1; // doesn't matter when using MPI |
| 94 | + int process_rank = -1; // doesn't matter when using MPI |
| 95 | + int gpus_per_node = -1; // doesn't matter when using MPI |
| 96 | + char server_ip[256] = ""; // doesn't matter when using MPI |
| 97 | + char fs_path[256] = ""; // doesn't matter when using MPI |
| 98 | + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); |
| 99 | + common_start(false, true); |
| 100 | + |
| 101 | + // set the right paths |
| 102 | + #if defined(ENABLE_BF16) |
| 103 | + const char* load_filename = "gpt2_124M_bf16.bin"; |
| 104 | + #else |
| 105 | + const char* load_filename = "gpt2_124M.bin"; |
| 106 | + #endif |
| 107 | + |
| 108 | + // build the GPT-2 model from a checkpoint |
| 109 | + GPT2 model; |
| 110 | + gpt2_init_common(&model); |
| 111 | + |
| 112 | + gpt2_build_from_checkpoint(&model, load_filename); |
| 113 | + size_t V = model.config.vocab_size; |
| 114 | + size_t Vp = model.config.padded_vocab_size; |
| 115 | + size_t maxT = model.config.max_seq_len; |
| 116 | + size_t L = model.config.num_layers; |
| 117 | + size_t C = model.config.channels; |
| 118 | + |
| 119 | + for (int i = 1; i < argc; i+=2) { |
| 120 | + if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag |
| 121 | + if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { exit(EXIT_FAILURE); } // must be -x[y] (one dash, one or two letters) |
| 122 | + if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash |
| 123 | + if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); } |
| 124 | + else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); } |
| 125 | + else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); } |
| 126 | + } |
| 127 | + |
| 128 | + // load additional information that we will use for debugging and error checking |
| 129 | + FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); |
| 130 | + int state_header[256]; |
| 131 | + freadCheck(state_header, sizeof(int), 256, state_file); |
| 132 | + if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } |
| 133 | + if (state_header[1] != 2) { |
| 134 | + fprintf(stderr, "Bad version in state file\n"); |
| 135 | + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); |
| 136 | + exit(EXIT_FAILURE); |
| 137 | + } |
| 138 | + int B = state_header[2]; // batch size, e.g. 4 |
| 139 | + int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) |
| 140 | + assert(0 <= T && T <= maxT); |
| 141 | + printf("[State]\n"); |
| 142 | + printf("batch_size: %d\n", B); |
| 143 | + printf("seq_len: %d\n", T); |
| 144 | + |
| 145 | + set_zero_configs(&multi_gpu_config, 0, model.num_parameters); |
| 146 | + |
| 147 | + // read reference information from the file saved from Python/PyTorch side |
| 148 | + // 1) input x and y |
| 149 | + int* x = (int*)mallocCheck(B * T * sizeof(int)); |
| 150 | + int* y = (int*)mallocCheck(B * T * sizeof(int)); |
| 151 | + freadCheck(x, sizeof(int), B*T, state_file); |
| 152 | + freadCheck(y, sizeof(int), B*T, state_file); |
| 153 | + // 2) results of forward pass (logits and loss) |
| 154 | + float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float)); |
| 155 | + float* expected_loss = (float*) mallocCheck(1 * sizeof(float)); |
| 156 | + freadCheck(expected_logits, sizeof(float), B*T*V, state_file); |
| 157 | + freadCheck(expected_loss, sizeof(float), 1, state_file); |
| 158 | + // 3) results of backward pass (parameter gradients) |
| 159 | + FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32 |
| 160 | + float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements); |
| 161 | + freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file); |
| 162 | + fcloseCheck(state_file); |
| 163 | + |
| 164 | + // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads |
| 165 | + void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes); |
| 166 | + float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float)); |
| 167 | + |
| 168 | + // overall OK signal for the test |
| 169 | + int allok = 1; |
| 170 | + |
| 171 | + gpt2_allocate_state(&model, B, T); |
| 172 | + |
| 173 | + // First, do target-free forward pass to validate logits |
| 174 | + gpt2_forward(&model, x, B, T); |
| 175 | + // at this point, target should be equal to expected_logits, let's compare |
| 176 | + // copy logits to CPU so we can compare them |
| 177 | + floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); |
| 178 | + float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); |
| 179 | + cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost)); |
| 180 | + for (int i = 0; i < B * T * Vp; i++) { |
| 181 | + logits_cpu[i] = (float)logits_cpu_raw[i]; |
| 182 | + } |
| 183 | + |
| 184 | + float logit_accuracy_threshold = 1e-3f; |
| 185 | + float loss_diff_threshold = 1e-5f; |
| 186 | + // FP16 and lower require very high tolerances unfortunately. TODO look into more |
| 187 | + #if defined(ENABLE_BF16) || defined(ENABLE_F16) |
| 188 | + logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! :( |
| 189 | + loss_diff_threshold = 0.05f; |
| 190 | + #endif |
| 191 | + |
| 192 | + // compare the output logits from the forward pass |
| 193 | + // also careful that we don't access and compare the padded columns of logits |
| 194 | + int logits_ok = 1; |
| 195 | + float max_diff = 0.0f; |
| 196 | + for (int bt = 0; bt < B*T; bt++) { |
| 197 | + for (int v = 0; v < V; v++) { |
| 198 | + int i = bt * Vp + v; // linearized index |
| 199 | + if (i < 10) { |
| 200 | + printf("%f, %f\n", expected_logits[i], logits_cpu[i]); |
| 201 | + } |
| 202 | + float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]); |
| 203 | + max_diff = fmaxf(max_diff, diff); |
| 204 | + if (diff >= logit_accuracy_threshold) { |
| 205 | + printf("MISMATCH AT INDEX %d,%d: ", bt, v); |
| 206 | + printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]); |
| 207 | + logits_ok = 0; |
| 208 | + bt = B*T; // to break out of both loops |
| 209 | + break; |
| 210 | + } |
| 211 | + } |
| 212 | + } |
| 213 | + allok = allok && logits_ok; |
| 214 | + if(!logits_ok) { printf("NOT "); } |
| 215 | + printf("OK (LOGITS)\n"); |
| 216 | + printf("logit max diff: %f\n", max_diff); |
| 217 | + |
| 218 | + // let's do 10 training iterations, following the pytorch code |
| 219 | + float losses[10]; |
| 220 | + for (int step = 0; step < 10; step++) { |
| 221 | + struct timespec start, end; |
| 222 | + clock_gettime(CLOCK_MONOTONIC, &start); |
| 223 | + gpt2_forward(&model, x, B, T); |
| 224 | + gpt2_backward_and_reduce(&model, x, y, 1, 0); |
| 225 | + clock_gettime(CLOCK_MONOTONIC, &end); |
| 226 | + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; |
| 227 | + |
| 228 | + if (step == 0) { |
| 229 | + // error checking at step 0 for reference activations |
| 230 | + |
| 231 | + // move the (mixed precision) grads from GPU to CPU |
| 232 | + cudaCheck(cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost)); |
| 233 | + |
| 234 | + // convert all gradients to float on the CPU |
| 235 | + char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char* |
| 236 | + float* dst_iterator = (float*)grads_memory_cpu_float; // float* |
| 237 | + float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python |
| 238 | + float* tensors1[NUM_PARAMETER_TENSORS]; |
| 239 | + float* tensors2[NUM_PARAMETER_TENSORS]; |
| 240 | + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { |
| 241 | + if (model.param_sizeof[i] == sizeof(float)) { |
| 242 | + // float tensor => copy over directly |
| 243 | + memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float)); |
| 244 | + } else { |
| 245 | + // low-precision tensor => convert to float |
| 246 | + assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm |
| 247 | + for (size_t j = 0; j < model.param_elements[i]; j++) { |
| 248 | + dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float |
| 249 | + } |
| 250 | + } |
| 251 | + // for convenience record the position of comparison for reality vs. expectation |
| 252 | + tensors1[i] = dst_iterator; // reality |
| 253 | + tensors2[i] = exp_iterator; // expectation |
| 254 | + // advance the iterators |
| 255 | + src_iterator += model.param_elements[i] * model.param_sizeof[i]; |
| 256 | + dst_iterator += model.param_elements[i]; |
| 257 | + exp_iterator += model.param_elements[i]; |
| 258 | + } |
| 259 | + |
| 260 | + // compare the gradients on the parameters all at once, in fp32 |
| 261 | + // I set the tolerances manually by inspecting the gradient differences for |
| 262 | + // a few elements of each tensor. bf16 looks ok but not amazing here. |
| 263 | + // It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure. |
| 264 | + // Also, if code changes and some of these get tripped, it could be ok if it's not by too much, |
| 265 | + // because our use of stochastic rounding is adding some non-determinism "pepper noise". |
| 266 | + // In that case it's ok to extend the tolerance by a bit, after a manual review. |
| 267 | + // Also, different GPUs may use different matrix multiplication algorithms, so the |
| 268 | + // actual errors can be hardware specific. |
| 269 | + |
| 270 | + float grad_thresholds[NUM_PARAMETER_TENSORS] = {5e-1f, 4e-3f, 1e-1f, 3.5e-2f, 2e-2f, 3e-2f, 5e-2f, 5e-2f, 5e-2f, 1.5e-2f, 5e-4f, 8e-3f, 1.5e-3f, 2.5e-3f, 1e-1f, 2e-2f}; |
| 271 | + #if defined(ENABLE_FP32) |
| 272 | + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { |
| 273 | + grad_thresholds[i] = 1e-6f; // we can be much more precise in FP32 |
| 274 | + } |
| 275 | + #endif |
| 276 | + |
| 277 | + allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]); |
| 278 | + allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", grad_thresholds[1]); |
| 279 | + allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]); |
| 280 | + allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]); |
| 281 | + allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]); |
| 282 | + allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", grad_thresholds[5]); |
| 283 | + allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", grad_thresholds[6]); |
| 284 | + allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", grad_thresholds[7]); |
| 285 | + allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", grad_thresholds[8]); |
| 286 | + allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", grad_thresholds[9]); |
| 287 | + allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", grad_thresholds[10]); |
| 288 | + allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", grad_thresholds[11]); |
| 289 | + allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", grad_thresholds[12]); |
| 290 | + allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", grad_thresholds[13]); |
| 291 | + allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", grad_thresholds[14]); |
| 292 | + allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); |
| 293 | + } |
| 294 | + |
| 295 | + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); |
| 296 | + float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; |
| 297 | + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); |
| 298 | + |
| 299 | + // print the timing information at the end |
| 300 | + printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); |
| 301 | + // the expected losses from PyTorch were copied over after the print formatting rounded |
| 302 | + // them to 6 decimal places, so we do the same here |
| 303 | + float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000; |
| 304 | + losses[step] = rounded_loss; |
| 305 | + } |
| 306 | + |
| 307 | + // expected losses are as follows, from Python |
| 308 | + float expected_losses[10] = { |
| 309 | + 5.270009f, |
| 310 | + 4.060681f, |
| 311 | + 3.320085f, |
| 312 | + 2.717550f, |
| 313 | + 2.181066f, |
| 314 | + 1.653923f, |
| 315 | + 1.168050f, |
| 316 | + 0.736873f, |
| 317 | + 0.401021f, |
| 318 | + 0.187493f |
| 319 | + }; |
| 320 | + |
| 321 | + // compare |
| 322 | + for (int i = 0; i < 10; i++) { |
| 323 | + if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) { |
| 324 | + printf("LOSS MISMATCH AT STEP %d: %f %f\n", i+1, losses[i], expected_losses[i]); |
| 325 | + allok = 0; |
| 326 | + } else { |
| 327 | + printf("loss ok at step %d: %f %f\n", i+1, losses[i], expected_losses[i]); |
| 328 | + } |
| 329 | + } |
| 330 | + |
| 331 | + // Finally, let's check determinism |
| 332 | + gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt"); |
| 333 | + |
| 334 | + DataLoader loader; |
| 335 | + dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_val.bin", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1); |
| 336 | + save_state("test_gpt2cu_state.ckpt", 10, &model, &loader); |
| 337 | + int tokens[10]; |
| 338 | + for (int step = 0; step < 10; step++) { |
| 339 | + dataloader_next_batch(&loader); |
| 340 | + gpt2_forward(&model, loader.inputs, B, T); |
| 341 | + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); |
| 342 | + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); |
| 343 | + losses[step] = model.mean_loss; |
| 344 | + tokens[step] = loader.inputs[0]; |
| 345 | + } |
| 346 | + |
| 347 | + // reload |
| 348 | + gpt2_free(&model); |
| 349 | + gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt"); |
| 350 | + int ld_step; |
| 351 | + gpt2_allocate_state(&model, B, T); |
| 352 | + load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt"); |
| 353 | + for (int step = 0; step < 10; step++) { |
| 354 | + dataloader_next_batch(&loader); |
| 355 | + gpt2_forward(&model, loader.inputs, B, T); |
| 356 | + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); |
| 357 | + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); |
| 358 | + |
| 359 | + if(loader.inputs[0] != tokens[step]) { |
| 360 | + printf("Nondeterminism! Token mismatch at step %d: %d vs %d\n", step, tokens[step], loader.inputs[0]); |
| 361 | + allok = false; |
| 362 | + break; |
| 363 | + } |
| 364 | + |
| 365 | + if(losses[step] != model.mean_loss) { |
| 366 | + printf("Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\n", step, losses[step], model.mean_loss); |
| 367 | + allok = false; |
| 368 | + break; |
| 369 | + } else { |
| 370 | + printf("loss ok at step %d: %f %f\n", step, losses[step], model.mean_loss); |
| 371 | + } |
| 372 | + } |
| 373 | + |
| 374 | + // final approval |
| 375 | + printf("overall okay: %d\n", allok); |
| 376 | + |
| 377 | + // delete intermediate test files |
| 378 | + remove("test_gpt2cu_model.ckpt"); |
| 379 | + remove("test_gpt2cu_state.ckpt"); |
| 380 | + |
| 381 | + // free everything |
| 382 | + dataloader_free(&loader); |
| 383 | + gpt2_free(&model); |
| 384 | + common_free(model); |
| 385 | + free(x); |
| 386 | + free(y); |
| 387 | + free(logits_cpu_raw); |
| 388 | + free(logits_cpu); |
| 389 | + free(expected_logits); |
| 390 | + free(expected_loss); |
| 391 | + free(expected_grads_memory); |
| 392 | + free(grads_memory_cpu); |
| 393 | + free(grads_memory_cpu_float); |
| 394 | + return allok ? EXIT_SUCCESS : EXIT_FAILURE; |
| 395 | +} |
0 commit comments