
Commit 0ddedf9

Merge pull request #745 from karpathy/feature/managed2

feature/managed2

2 parents: 4c84bc7 + 18298f3

File tree: 4 files changed, +46 -9 lines

  Makefile
  llmc/cuda_common.h
  llmc/cuda_utils.cuh
  train_gpt2.cu

Makefile

Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,9 @@ endif
 ifneq ($(CI),true) # if not in CI, then use the GPU query
   ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
     ifneq ($(call file_exists_in_path, nvidia-smi),)
-      GPU_COMPUTE_CAPABILITY = $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g')
+      # Get the compute capabilities of all GPUs
+      # Remove decimal points, sort numerically in ascending order, and select the first (lowest) value
+      GPU_COMPUTE_CAPABILITY = $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g' | sort -n | head -n 1)
       GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
     endif
   endif
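The new Makefile logic targets the lowest compute capability among all visible GPUs, so a single binary can run on every card in a mixed machine (e.g. a box with 8.0 and 8.6 GPUs builds for 80). As a rough cross-check, not part of the commit and with a purely illustrative program name, the same value can be computed with the CUDA runtime API:

// query_min_cc.cu (illustrative): print the lowest compute capability across all GPUs,
// in the same "decimal point removed" form the Makefile pipeline produces (e.g. 8.6 -> 86).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    if (cudaGetDeviceCount(&count) != cudaSuccess || count == 0) { return 1; }
    int min_cc = 1 << 30;
    for (int d = 0; d < count; d++) {
        int major = 0, minor = 0;
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, d);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, d);
        int cc = major * 10 + minor;        // mirrors sed 's/\.//g'
        if (cc < min_cc) { min_cc = cc; }   // mirrors sort -n | head -n 1
    }
    printf("%d\n", min_cc);
    return 0;
}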

llmc/cuda_common.h

Lines changed: 3 additions & 3 deletions

@@ -48,14 +48,14 @@ constexpr std::bool_constant<true> False;
 // ----------------------------------------------------------------------------
 // Error checking

-// CUDA error checking
-inline void cudaCheck(cudaError_t error, const char *file, int line) {
+// CUDA error checking. Underscore added so this function can be called directly not just via macro
+inline void cudaCheck_(cudaError_t error, const char *file, int line) {
     if (error != cudaSuccess) {
         printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
         exit(EXIT_FAILURE);
     }
 };
-#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
+#define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))

 // like cudaFree, but checks for errors _and_ resets the pointer.
 template<class T>
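To illustrate why the underscore rename matters (the sketch and its my_alloc_ helper are hypothetical, not part of the commit): any wrapper that receives __FILE__/__LINE__ from its own caller can now forward them to cudaCheck_ directly, so CUDA errors are reported at the call site rather than inside the wrapper. cudaMallocConditionallyManaged below relies on exactly this.

#include <cuda_runtime.h>
#include "llmc/cuda_common.h"

// hypothetical helper following the same pattern as cudaMallocConditionallyManaged
inline void* my_alloc_(size_t bytes, const char* file, int line) {
    void* ptr = nullptr;
    cudaCheck_(cudaMalloc(&ptr, bytes), file, line);  // direct call, reports the caller's file:line
    return ptr;
}
#define my_alloc(bytes) (my_alloc_((bytes), __FILE__, __LINE__))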

llmc/cuda_utils.cuh

Lines changed: 23 additions & 0 deletions

@@ -205,6 +205,29 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud
     cudaCheck(cudaGetLastError());
 }

+// ----------------------------------------------------------------------------
+// memory management
+
+// allocate memory, preferably on the device
+// returns a status code. 0 = OK, 1 = fell back to managed memory
+int cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
+    // try to allocate
+    cudaError_t err = cudaMalloc(out, bytes);
+    if(err == cudaErrorMemoryAllocation) {
+        // if we OOM, fallback to a managed allocation. slower but at least won't crash.
+        cudaGetLastError(); // reset the error before the next API call
+        cudaCheck_(cudaMallocManaged(out, bytes), file, line);
+        cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line);
+        return 1;
+    } else {
+        cudaCheck_(err, file, line);
+        return 0;
+    }
+}
+
+#define cudaMallocConditionallyManaged(out, bytes)\
+(cudaMallocConditionallyManaged((void**)out, bytes, __FILE__, __LINE__))
+
 // ----------------------------------------------------------------------------
 // Random Number Generation used in Stochastic Rounding
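A minimal usage sketch of the new helper (the allocate_scratch function and its arguments are hypothetical, not part of the commit), mirroring how train_gpt2.cu uses it below: the caller ORs the returned status into a flag and decides later whether to warn.

#include <cstdio>
#include "llmc/cuda_common.h"
#include "llmc/cuda_utils.cuh"

// hypothetical caller: allocate n floats, preferring device memory
void allocate_scratch(float** scratch, size_t n) {
    int fell_back = cudaMallocConditionallyManaged((void**)scratch, n * sizeof(float));
    if (fell_back) {
        printf("warning: scratch buffer ended up in managed memory; expect host<->device paging\n");
    }
}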

train_gpt2.cu

Lines changed: 17 additions & 5 deletions

@@ -37,7 +37,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/cuda_common.h"
 // defines:
 // Packed128, f128, x128
-// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
+// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged
 #include "llmc/cuda_utils.cuh"
 // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
 // defines: cublas_compute, cublaslt_handle, cublas_handle

@@ -388,24 +388,36 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) {
     model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
     model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);

+    // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device
+    // and returns a status code of 1 if it had to fall back, in that case we want to print warning.
+    int memory_status = 0;
+
+    // we will now init the optimizer states and master weights
+    // this is usually a substantial amount of memory allocation right here.
     size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
     printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
     printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
     assert(model->m_memory == nullptr);
     assert(model->v_memory == nullptr);
-    cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
-    cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));

     if (model->use_master_weights == 1) {
         assert(model->master_weights == nullptr);
         printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
-        cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float)));
+        memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));
     }

+    // report on mixed memory allocation status (re-using our float reduce function, bit awk ok)
+    int reduced_memory_status = (int) multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config);
+    if (reduced_memory_status >= 1) {
+        printf0("WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\n", reduced_memory_status);
+        printf0("         Prevents an OOM, but code may run much slower due to device <-> host memory movement\n");
+    }
+    // report on device memory usage
     size_t free, total;
     cudaCheck(cudaMemGetInfo(&free, &total));
     printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024);
-
     // give an estimate of the maximum batch size
     size_t bytes_per_sequence = 0;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {