@@ -37,7 +37,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/cuda_common.h"
 // defines:
 // Packed128, f128, x128
-// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
+// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged
 #include "llmc/cuda_utils.cuh"
 // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
 // defines: cublas_compute, cublaslt_handle, cublas_handle
@@ -388,24 +388,36 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) {
     model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
     model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
 
+    // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device
+    // and returns a status code of 1 if it had to fall back, in that case we want to print warning.
+    int memory_status = 0;
+
+    // we will now init the optimizer states and master weights
+    // this is usually a substantial amount of memory allocation right here.
     size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
     printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
     printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
     assert(model->m_memory == nullptr);
     assert(model->v_memory == nullptr);
-    cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
-    cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));
 
     if (model->use_master_weights == 1) {
         assert(model->master_weights == nullptr);
         printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
-        cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
+        memory_status |= cudaMallocConditionallyManaged((void**)&model->master_weights, shard_num_parameters * sizeof(float));
     }
 
+    // report on mixed memory allocation status (re-using our float reduce function, bit awk ok)
+    int reduced_memory_status = (int) multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config);
+    if (reduced_memory_status >= 1) {
+        printf0("WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\n", reduced_memory_status);
+        printf0("         Prevents an OOM, but code may run much slower due to device <-> host memory movement\n");
+    }
+    // report on device memory usage
     size_t free, total;
     cudaCheck(cudaMemGetInfo(&free, &total));
     printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024);
-
     // give an estimate of the maximum batch size
     size_t bytes_per_sequence = 0;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
0 commit comments