
Commit 0cd9f1a

Merge pull request #530 from gordicaleksa/minor_refactor
Minor refactor: remove hardcoded val, delete unused vars...
2 parents: d17893b + 42ca620, commit 0cd9f1a

File tree: 1 file changed, +7 -9 lines

train_gpt2.cu (+7 -9)
@@ -1407,7 +1407,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
 
     // multiply all elements of preatt elementwise by scale
     float scale = 1.0 / sqrtf(HS);
-    int grid_size = CEIL_DIV(B * NH * T * 32, block_size);
+    int grid_size = CEIL_DIV(B * NH * T * WARP_SIZE, block_size);
     softmax_forward_kernel5<<<grid_size, block_size>>>(att, scale, preatt, B * NH, T);
 
     // new approach: first cuBLAS another batched matmul
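
The only change in this hunk is replacing the hardcoded 32 with the WARP_SIZE constant when sizing the softmax launch. Below is a minimal sketch of the grid-size arithmetic, assuming CEIL_DIV is the usual round-up division macro and that softmax_forward_kernel5 assigns one warp per (batch, head, time) row, which is what the B * NH * T * WARP_SIZE thread count implies:

    // sketch: grid sizing for a warp-per-row kernel
    // assumed definitions, in the spirit of train_gpt2.cu:
    #define WARP_SIZE 32
    #define CEIL_DIV(m, n) (((m) + (n) - 1) / (n))

    // B * NH * T rows, one warp (WARP_SIZE threads) per row, rounded up to whole blocks
    int total_threads = B * NH * T * WARP_SIZE;
    int grid_size = CEIL_DIV(total_threads, block_size);
    softmax_forward_kernel5<<<grid_size, block_size>>>(att, scale, preatt, B * NH, T);

Using WARP_SIZE ties the launch math to the kernel's warp-per-row layout instead of a magic number.
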
@@ -1683,10 +1683,8 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
 // allocate memory for the parameters and point the individual tensors to the right places
 void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) {
     // calculate the total number of parameters and bytes across all tensors
-    size_t num_parameters = 0;
     size_t num_parameters_bytes = 0;
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
-        num_parameters += param_elements[i];
         num_parameters_bytes += param_elements[i] * param_sizeof[i];
     }
     // malloc all parameters all at once on the device
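
The deleted num_parameters accumulator was unused in this function: the single device allocation that follows is sized in bytes, so only num_parameters_bytes matters. A hedged sketch of the allocate-once-then-carve pattern the surrounding code uses; ptrs is an illustrative stand-in for however the function records each tensor's pointer, not the exact code in train_gpt2.cu:

    // sketch: size the buffer in bytes, allocate once, then point each tensor into it
    size_t num_parameters_bytes = 0;
    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters_bytes += param_elements[i] * param_sizeof[i];
    }
    void* params_memory;
    cudaCheck(cudaMalloc(&params_memory, num_parameters_bytes)); // one big device allocation
    char* cursor = (char*)params_memory;
    char* ptrs[NUM_PARAMETER_TENSORS];                  // illustrative: per-tensor destinations
    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        ptrs[i] = cursor;                               // tensor i starts at this offset
        cursor += param_elements[i] * param_sizeof[i];  // advance by its size in bytes
    }
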
@@ -2433,7 +2431,7 @@ float multi_gpu_cpu_float_sum(float value) {
 
 // Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
 // todo - this version only works if all the parameters are the same size (floatX)
-void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config) {
+void gpt2_multi_gpu_grad_reduce(GPT2* model, MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
     NVTX_RANGE_FN();
     if (multi_gpu_config->num_processes == 1) { return; }
@@ -2490,12 +2488,12 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
     // repurposing this buffer (which isn't needed now) to write grad norm into it
     float* grad_norm_squared = (float*)model->acts.output;
     if (multi_gpu_config->zero_stage == 1) {
-        // ^1 because of the ncclReduceScatter() in gpt2_multi_gpu_accumulate,
+        // ^1 because of the ncclReduceScatter() in gpt2_multi_gpu_grad_reduce,
         // grads_memory only contains the averaged gradients at the local shard
         // so we only calculate the grad norm at the grads_memory belonging to the local shard
         global_norm_squared(grad_norm_squared, grads_memory + shard_offset, shard_num_parameters);
     } else {
-        // the ncclAllReduce() in gpt2_multi_gpu_accumulate has averaged the gradients across all GPUs
+        // the ncclAllReduce() in gpt2_multi_gpu_grad_reduce has averaged the gradients across all GPUs
         // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed
         global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters);
     }
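
The rename from gpt2_multi_gpu_accumulate to gpt2_multi_gpu_grad_reduce matches what the comments above describe: with zero_stage == 1 the reduction is a ncclReduceScatter, so each rank keeps only its shard of the averaged gradients and the norm is taken over that shard; otherwise a ncclAllReduce leaves every rank with the full averaged vector. A hedged sketch of those two paths; comm, stream, and ncclCheck are illustrative stand-ins, ncclFloat stands in for whatever floatX maps to, and shard_offset / shard_num_parameters are the same quantities used in the norm calls above:

    // sketch of the two reduction paths referenced by the comments above
    if (multi_gpu_config->zero_stage == 1) {
        // ZeRO-1: scatter the averaged gradients, each rank keeps only its own shard
        ncclCheck(ncclReduceScatter(grads_memory, grads_memory + shard_offset,
                                    shard_num_parameters, ncclFloat, ncclAvg,
                                    comm, stream));
    } else {
        // no sharding: every rank receives the full averaged gradient vector
        ncclCheck(ncclAllReduce(grads_memory, grads_memory,
                                model->num_parameters, ncclFloat, ncclAvg,
                                comm, stream));
    }
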
@@ -2583,7 +2581,7 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
     return grad_norm_cpu;
 }
 
-void gpt2_multi_gpu_gather(GPT2 *model, MultiGpuConfig* multi_gpu_config)
+void gpt2_multi_gpu_param_gather(GPT2 *model, MultiGpuConfig* multi_gpu_config)
 {
 #ifdef MULTI_GPU
     if (multi_gpu_config->num_processes == 1) { return; } // 1 process => noop
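
gpt2_multi_gpu_param_gather is the counterpart on the parameter side: under ZeRO-1 each rank's optimizer step only touched its own parameter shard, so the shards have to be all-gathered back into the full buffer before the next forward pass. A hedged sketch; params_memory, shard_offset, shard_num_parameters, comm, stream, and ncclCheck are illustrative stand-ins for the actual fields:

    // sketch: gather every rank's updated shard back into the full parameter buffer
    if (multi_gpu_config->zero_stage == 1) {
        ncclCheck(ncclAllGather(params_memory + shard_offset,  // this rank's updated shard
                                params_memory,                 // full buffer, filled by all ranks
                                shard_num_parameters, ncclFloat,
                                comm, stream));
    }
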
@@ -3160,7 +3158,7 @@ int main(int argc, char *argv[]) {
         // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
         model.mean_loss = lossf;
         // update the parameters
-        gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
+        gpt2_multi_gpu_grad_reduce(&model, &multi_gpu_config);
         // learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
         float step_learning_rate = learning_rate;
         if (step < warmup_iterations) {
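
The context lines above are the start of the schedule the comment describes: linear warmup to the maximum learning rate, then cosine decay down to learning_rate * final_learning_rate_frac. Below is a sketch written to match that comment rather than copied from the file; using train_num_batches as the total step count is an assumption:

        float step_learning_rate = learning_rate;
        if (step < warmup_iterations) {
            // linear warmup from ~0 up to the max learning rate
            step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
        } else {
            // cosine decay from learning_rate down to learning_rate * final_learning_rate_frac
            float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
            float coeff = 0.5f * (1.0f + cosf((float)M_PI * decay_ratio));  // goes 1 -> 0 over training
            float min_lr = learning_rate * final_learning_rate_frac;
            step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
        }
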
@@ -3175,7 +3173,7 @@ int main(int argc, char *argv[]) {
         }
         // update the model parameters
         float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
-        gpt2_multi_gpu_gather(&model, &multi_gpu_config);
+        gpt2_multi_gpu_param_gather(&model, &multi_gpu_config);
         // zero out the gradients for the next iteration
         gpt2_zero_grad(&model);
         cudaCheck(cudaEventRecord(end));
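
Taken together, the two main-loop hunks place the renamed collectives in the per-step sequence; a condensed view of that ordering, with the other steps elided:

        // per training step, after the backward pass:
        gpt2_multi_gpu_grad_reduce(&model, &multi_gpu_config);   // average grads across GPUs (ZeRO-1: reduce-scatter)
        float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f,
                                      weight_decay, 1.0f, step+1, &multi_gpu_config);  // optimizer update
        gpt2_multi_gpu_param_gather(&model, &multi_gpu_config);  // ZeRO-1: all-gather the updated param shards
        gpt2_zero_grad(&model);                                  // clear grads for the next iteration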
