Merge pull request #745 from karpathy/feature/managed2
feature/managed2
karpathy authored Aug 16, 2024
2 parents 4c84bc7 + 18298f3 commit 0ddedf9
Showing 4 changed files with 46 additions and 9 deletions.
4 changes: 3 additions & 1 deletion Makefile
@@ -49,7 +49,9 @@ endif
ifneq ($(CI),true) # if not in CI, then use the GPU query
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
ifneq ($(call file_exists_in_path, nvidia-smi),)
GPU_COMPUTE_CAPABILITY = $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g')
# Get the compute capabilities of all GPUs
# Remove decimal points, sort numerically in ascending order, and select the first (lowest) value
GPU_COMPUTE_CAPABILITY=$(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g' | sort -n | head -n 1)
GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
endif
endif
6 changes: 3 additions & 3 deletions llmc/cuda_common.h
@@ -48,14 +48,14 @@ constexpr std::bool_constant<true> False;
// ----------------------------------------------------------------------------
// Error checking

// CUDA error checking
inline void cudaCheck(cudaError_t error, const char *file, int line) {
// CUDA error checking. Underscore added so this function can be called directly not just via macro
inline void cudaCheck_(cudaError_t error, const char *file, int line) {
if (error != cudaSuccess) {
printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
exit(EXIT_FAILURE);
}
};
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
#define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))

// like cudaFree, but checks for errors _and_ resets the pointer.
template<class T>
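The underscore rename exists so that other code can report its caller's location: ordinary call sites keep using the cudaCheck(err) macro, while a wrapper can call cudaCheck_ directly and forward a file/line pair it received from its own caller. A minimal standalone sketch of that pattern follows; it repeats the two definitions above, and the deviceCallocCheck helper is hypothetical, not part of this commit.

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

// CUDA error checking, as in cuda_common.h after this change
inline void cudaCheck_(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}
#define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))

// hypothetical helper (not in the commit): allocate and zero a device buffer,
// forwarding the caller's file/line so any CUDA error is attributed to the call site
void* deviceCallocCheck_(size_t bytes, const char *file, int line) {
    void* ptr = NULL;
    cudaCheck_(cudaMalloc(&ptr, bytes), file, line);
    cudaCheck_(cudaMemset(ptr, 0, bytes), file, line);
    return ptr;
}
#define deviceCallocCheck(bytes) (deviceCallocCheck_(bytes, __FILE__, __LINE__))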
23 changes: 23 additions & 0 deletions llmc/cuda_utils.cuh
@@ -205,6 +205,29 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud
cudaCheck(cudaGetLastError());
}

// ----------------------------------------------------------------------------
// memory management

// allocate memory, preferably on the device
// returns a status code. 0 = OK, 1 = fell back to managed memory
int cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
// try to allocate
cudaError_t err = cudaMalloc(out, bytes);
if(err == cudaErrorMemoryAllocation) {
// if we OOM, fallback to a managed allocation. slower but at least won't crash.
cudaGetLastError(); // reset the error before the next API call
cudaCheck_(cudaMallocManaged(out, bytes), file, line);
cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line);
return 1;
} else {
cudaCheck_(err, file, line);
return 0;
}
}

#define cudaMallocConditionallyManaged(out, bytes)\
(cudaMallocConditionallyManaged((void**)out, bytes, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// Random Number Generation used in Stochastic Rounding

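A minimal usage sketch of the new allocator, assuming llmc/cuda_common.h and llmc/cuda_utils.cuh are included; the buffer name and size are made up for illustration. The macro is called like cudaMalloc, and the return value records whether the allocation had to fall back to managed memory.

// hypothetical example, not part of the commit
void allocate_scratch_example(void) {
    float* scratch = NULL;
    int fell_back = cudaMallocConditionallyManaged((void**)&scratch, 1024 * sizeof(float));
    if (fell_back) {
        printf("note: scratch buffer is in managed memory; expect host <-> device paging\n");
    }
    // the pointer is usable from kernels either way, and is freed with cudaFree as usual
    cudaCheck(cudaFree(scratch));
}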
22 changes: 17 additions & 5 deletions train_gpt2.cu
@@ -37,7 +37,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/cuda_common.h"
// defines:
// Packed128, f128, x128
// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged
#include "llmc/cuda_utils.cuh"
// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
// defines: cublas_compute, cublaslt_handle, cublas_handle
@@ -388,24 +388,36 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) {
model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);

// cudaMallocConditionallyManaged can fall back to cudaMallocManaged if there is not enough memory on the device;
// it returns a status code of 1 if it had to fall back, in which case we want to print a warning.
int memory_status = 0;

// we will now init the optimizer states and master weights
// this is usually a substantial amount of memory allocation right here.
size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
assert(model->m_memory == nullptr);
assert(model->v_memory == nullptr);
cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));

if (model->use_master_weights == 1) {
assert(model->master_weights == nullptr);
printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float)));
memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));
}

// report on mixed memory allocation status (re-using our float reduce function, bit awk ok)
int reduced_memory_status = (int) multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config);
if (reduced_memory_status >= 1) {
printf0("WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\n", reduced_memory_status);
printf0(" Prevents an OOM, but code may run much slower due to device <-> host memory movement\n");
}
// report on device memory usage
size_t free, total;
cudaCheck(cudaMemGetInfo(&free, &total));
printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024);

// give an estimate of the maximum batch size
size_t bytes_per_sequence = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
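The per-process status codes are ORed locally and then summed across processes, so the reduced value counts how many GPUs fell back. A standalone sketch of that counting idea follows, assuming an MPI launch with one rank per GPU; it uses plain MPI_Allreduce for illustration and is not the repo's multi_gpu_cpu_float_sum helper.

#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int memory_status = (rank == 0) ? 1 : 0;  // pretend only rank 0 fell back
    int fallback_count = 0;
    // each rank contributes 0 or 1, so the sum is the number of GPUs that fell back
    MPI_Allreduce(&memory_status, &fallback_count, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
    if (rank == 0 && fallback_count >= 1) {
        printf("WARNING: Fell back to cudaMallocManaged on %d GPUs\n", fallback_count);
    }
    MPI_Finalize();
    return 0;
}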
