Re-introduce cuda streams #550

Closed
wants to merge 5 commits into from
16 changes: 16 additions & 0 deletions llmc/adamw.cuh
@@ -47,3 +47,19 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
// this will be used in the next update
if (master_params_memory != NULL) { master_params_memory[idx] = param; }
}

template <typename Tp, typename Tg>
void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
float grad_scale, unsigned int seed, cudaStream_t stream) {
// AdamW update
int block_size = 512;
int num_blocks = CEIL_DIV(num_parameters, block_size);
float beta1_correction = 1.0f - powf(beta1, t);
float beta2_correction = 1.0f - powf(beta2, t);
adamw_kernel3<<<num_blocks, block_size, 0, stream>>>(params_memory, master_params_memory, grads_memory,
m_memory, v_memory, num_parameters,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
grad_scale, seed);
cudaCheck(cudaGetLastError());
}
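For reference, a minimal calling sketch for the new stream-aware launcher. It assumes the optimizer buffers are already allocated on the device, that cudaCheck comes from cuda_common.h, and that main_stream is a stream created by the caller (the name is illustrative, not taken from this PR):

    cudaStream_t main_stream;
    cudaCheck(cudaStreamCreate(&main_stream));
    // the launch is enqueued on main_stream and returns immediately on the host
    adamw_update(params_memory, master_params_memory, grads_memory,
                 m_memory, v_memory, num_parameters,
                 learning_rate, beta1, beta2, t, eps, weight_decay,
                 grad_scale, seed, main_stream);
    // synchronize only when the host actually needs the updated parameters
    cudaCheck(cudaStreamSynchronize(main_stream));
    cudaCheck(cudaStreamDestroy(main_stream));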
20 changes: 12 additions & 8 deletions llmc/attention.cuh
@@ -190,7 +190,7 @@ __global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const fl

void attention_forward(floatX* out, floatX* qkvr, floatX* att,
floatX* inp,
int B, int T, int C, int NH) {
int B, int T, int C, int NH, cudaStream_t stream) {
NVTX_RANGE_FN();
// Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
// Its contents will be overwritten by this function.
@@ -209,9 +209,11 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
v = qkvr + 2 * B * T * C;
int total_threads = B * NH * T * HS;
int num_blocks = CEIL_DIV(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
permute_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, v, inp, B, T, NH, HS);

floatX* preatt = inp;
cublasCheck(cublasSetStream(cublas_handle, stream));
cublasCheck(cublasSetWorkspace(cublas_handle, cublaslt_workspace, cublaslt_workspace_size));
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
T, T, HS, &alpha,
@@ -223,7 +225,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
// multiply all elements of preatt elementwise by scale
float scale = 1.0 / sqrtf(HS);
int grid_size = CEIL_DIV(B * NH * T * WARP_SIZE, block_size);
softmax_forward_kernel5<<<grid_size, block_size>>>(att, scale, preatt, B * NH, T);
softmax_forward_kernel5<<<grid_size, block_size, 0, stream>>>(att, scale, preatt, B * NH, T);

// new approach: first cuBLAS another batched matmul
floatX* vaccum = inp;
@@ -239,7 +241,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
// now unpermute
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = CEIL_DIV(B * T * C, block_size);
unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
unpermute_kernel<<<num_blocks, block_size, 0, stream>>>(vaccum, out, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}

@@ -248,7 +250,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* datt, floatX* scratch,
const floatX* dout,
const floatX* qkvr, const floatX* att,
int B, int T, int C, int NH) {
int B, int T, int C, int NH, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
int HS = C / NH; // head size
@@ -266,8 +268,10 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da

// backward through the unpermute operation
int num_blocks = CEIL_DIV(B * T * C, block_size);
unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);
unpermute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(scratch, dout, B, T, NH, HS);
// backward into datt
cublasCheck(cublasSetStream(cublas_handle, stream));
cublasCheck(cublasSetWorkspace(cublas_handle, cublaslt_workspace, cublaslt_workspace_size));
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &alpha,
v, CUBLAS_LOWP, HS, T * HS, scratch, CUBLAS_LOWP, HS, T * HS, &beta,
datt, CUBLAS_LOWP, T, T * T, B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
@@ -278,7 +282,7 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
// backward into preatt
int hs = C / NH; // head size
float scale = 1.0f / sqrtf(hs);
softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);
softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(dpreatt, datt, att, B, T, C, scale);
// backward into q
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha,
k, CUBLAS_LOWP, HS, T * HS, dpreatt, CUBLAS_LOWP, T, T * T, &beta,
@@ -289,6 +293,6 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
dk, CUBLAS_LOWP, HS, T * HS, B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
// backward into inp
num_blocks = CEIL_DIV(B * NH * T * HS, block_size);
permute_kernel_backward<<<num_blocks, block_size>>>(dinp, dq, dk, dv, B, T, NH, HS);
permute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(dinp, dq, dk, dv, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}
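Since cublas_handle is a shared global, every launcher that issues GEMMs now rebinds it to the caller's stream (and re-sets the workspace) before the batched matmuls, so the cuBLAS work is ordered with the surrounding custom kernels on that same stream. If this repetition grows, it could be factored into a tiny helper; a sketch under the same globals used above, with a hypothetical name:

    // bind the shared cuBLAS handle to a stream before enqueuing GEMMs on it
    static void cublas_bind_stream(cudaStream_t stream) {
        cublasCheck(cublasSetStream(cublas_handle, stream));
        cublasCheck(cublasSetWorkspace(cublas_handle, cublaslt_workspace, cublaslt_workspace_size));
    }

Calls issued through the handle after this point run on that stream until it is rebound.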
1 change: 1 addition & 0 deletions llmc/cuda_common.h
@@ -8,6 +8,7 @@ Common utilities for CUDA code.
#include <string>
#include <cuda_runtime.h>
#include <nvtx3/nvToolsExt.h>
#include <nvtx3/nvToolsExtCudaRt.h>
#include <cuda_profiler_api.h>
#include <cuda_bf16.h>

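The extra nvToolsExtCudaRt.h include brings in the NVTX helpers for naming CUDA runtime objects, which becomes useful once dedicated streams exist, for example naming a stream so it shows up legibly in an Nsight Systems timeline (main_stream below is a hypothetical caller-owned stream):

    nvtxNameCudaStreamA(main_stream, "main stream");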
22 changes: 16 additions & 6 deletions llmc/cudnn_att.cpp
@@ -21,9 +21,16 @@ static_assert(false, "cuDNN is not supported in FP32 mode.")
static cudnnHandle_t cudnn_handle;
static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
static void* cudnn_workspace = NULL;
#define checkCudnnErr(err) assert((int)err == 0);

static void checkCudnnFE(fe::error_object e, const char *file, int line) {
static void cuDNNCheck(cudnnStatus_t error, const char *file, int line) {
if (error != CUDNN_STATUS_SUCCESS) {
printf("[CUDNN ERROR] at file %s:%d:\n%s\n", file, line, cudnnGetErrorString(error));
exit(EXIT_FAILURE);
}
};
#define cuDNNCheck(err) (cuDNNCheck(err, __FILE__, __LINE__))

static void checkCudnnFE(const fe::error_object& e, const char *file, int line) {
if(!e.is_good()) {
printf("[CUDNN ERROR] at file %s:%d:\n%s\n", file, line, e.err_msg.c_str());
exit(EXIT_FAILURE);
@@ -211,11 +218,13 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C) {
int B, int T, int NH, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head
bool is_inference_only = (stats == nullptr);

cuDNNCheck(cudnnSetStream(cudnn_handle, stream));

// Get graph and tensors from cache (or generate it on first use)
auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);

@@ -242,7 +251,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)

void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C) {
int B, int T, int NH, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head

@@ -269,15 +278,16 @@ void attention_backward_cudnn(floatX* dqkvr,
{Attn_scale_UID, &attn_scale_cpu}};

// Execute graph
cuDNNCheck(cudnnSetStream(cudnn_handle, stream));
checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));
cudaCheck(cudaGetLastError());
}

void create_cudnn() {
checkCudnnErr(cudnnCreate(&cudnn_handle));
cuDNNCheck(cudnnCreate(&cudnn_handle));
}

void destroy_cudnn() {
if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); }
checkCudnnErr(cudnnDestroy(cudnn_handle));
cuDNNCheck(cudnnDestroy(cudnn_handle));
}
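A note on the new error handling: cuDNNCheck is first defined as a function and only then shadowed by a function-like macro of the same name, mirroring the cudaCheck pattern in cuda_common.h. Because the macro is defined after the function body, the function itself is unaffected, while every later call site automatically picks up __FILE__ and __LINE__. For example:

    cuDNNCheck(cudnnSetStream(cudnn_handle, stream));
    // expands (after the #define) to:
    (cuDNNCheck(cudnnSetStream(cudnn_handle, stream), __FILE__, __LINE__));

This also replaces the old checkCudnnErr assert, so cuDNN failures now print a message and exit instead of vanishing in builds where assert is compiled out.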
4 changes: 2 additions & 2 deletions llmc/cudnn_att.h
@@ -12,10 +12,10 @@ void destroy_cudnn();
void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C);
int B, int T, int NH, int C, cudaStream_t stream);

void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C);
int B, int T, int NH, int C, cudaStream_t stream);

#endif // CUDNN_ATT_H
14 changes: 7 additions & 7 deletions llmc/encoder.cuh
@@ -149,27 +149,27 @@ __global__ void wpe_backward_kernel(floatX* dwpe,

void encoder_forward(floatX* out,
const int* inp, const floatX* wte, const floatX* wpe,
int B, int T, int C) {
int B, int T, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
const int N = B * T * C;
const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
encoder_forward_kernel3<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);
cudaCheck(cudaGetLastError());
}

// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details)
void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed) {
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
NVTX_RANGE_FN();

// Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
const int block_size = 256;
const int N = T * C / x128::size;
const int grid_size = CEIL_DIV(N, block_size);
wpe_backward_kernel<<<grid_size, block_size, 0>>>(dwpe, dout, inp, B, T, C, seed);
wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
cudaCheck(cudaGetLastError());

// check the GPU scratch buffer is large enough to hold the bucket info and workload indices
@@ -217,11 +217,11 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output
// todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
int4* d_bucket_info = (int4*)scratch;
int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4));
cudaCheck(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice));
cudaCheck(cudaMemcpy(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice));
cudaCheck(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream));
cudaCheck(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream));

// Launch wte kernel
// todo - profile block sizes on more content (depends on number of buckets and on GPU?)
wte_backward_kernel<256><<<num_buckets, 256>>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C);
wte_backward_kernel<256><<<num_buckets, 256, 0, stream>>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C);
cudaCheck(cudaGetLastError());
}
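Both host-to-device copies are now cudaMemcpyAsync on the caller's stream, so they are guaranteed to complete before wte_backward_kernel runs on that same stream. One caveat worth keeping in mind: cudaMemcpyAsync only overlaps with host work when the source is pinned memory; if workload_indices and bucket_info are ordinary pageable allocations, the copies still stage through the driver and behave mostly synchronously with respect to the CPU. A sketch of what pinned allocation would look like, with hypothetical byte counts:

    int* workload_indices;
    int4* bucket_info;
    cudaCheck(cudaMallocHost((void**)&workload_indices, workload_indices_bytes));
    cudaCheck(cudaMallocHost((void**)&bucket_info, bucket_info_bytes));
    // ... use as before, then release with cudaFreeHost(...) instead of free(...)

Whether that trade-off is worthwhile here depends on how large these buffers get in practice.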
4 changes: 2 additions & 2 deletions llmc/fused_classifier.cuh
@@ -133,11 +133,11 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
template <typename Type>
void fused_classifier(Type* logits, Type* losses,
const float dloss, const int* targets,
int B, int T, int V, int P) {
int B, int T, int V, int P, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 1024;
const int N = B * T;
const int grid_size = N;
fused_classifier_kernel5<<<grid_size, block_size>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
fused_classifier_kernel5<<<grid_size, block_size, 0, stream>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
cudaCheck(cudaGetLastError());
}
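With the launch now enqueued on the caller's stream, reading the per-token losses back on the host has to be ordered on that same stream as well. A minimal sketch, assuming losses_host is a hypothetical host buffer of B*T elements:

    fused_classifier(logits, losses, dloss, targets, B, T, V, P, stream);
    cudaCheck(cudaMemcpyAsync(losses_host, losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost, stream));
    cudaCheck(cudaStreamSynchronize(stream));  // losses_host is only valid after this returns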
8 changes: 4 additions & 4 deletions llmc/gelu.cuh
@@ -47,20 +47,20 @@ __global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp
// ----------------------------------------------------------------------------
// kernel launchers

void gelu_forward(floatX* out, const floatX* inp, int N) {
void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 512;
assert(N % block_size == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
gelu_forward_kernel2<<<grid_size, block_size>>>(out, inp);
gelu_forward_kernel2<<<grid_size, block_size, 0, stream>>>(out, inp);
cudaCheck(cudaGetLastError());
}

void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N) {
void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 128;
assert(N % block_size == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
gelu_backward_inplace_kernel<<<grid_size, block_size>>>(d_in_out, inp);
gelu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, inp);
cudaCheck(cudaGetLastError());
}
6 changes: 3 additions & 3 deletions llmc/global_norm.cuh
@@ -34,7 +34,7 @@ __global__ void global_norm_squared_kernel(float* out, const T* data, size_t cou
// kernel launcher

template<typename T>
void global_norm_squared(float* out, const T* values, size_t count) {
void global_norm_squared(float* out, const T* values, size_t count, cudaStream_t stream) {
const int block_size = 512;
// launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
// having one block less than possible is a tiny performance hit, having
@@ -44,8 +44,8 @@ void global_norm_squared(float* out, const T* values, size_t count) {
const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
assert(grid_size > 0); // gives a better error than letting the call below fail
// initialize out with zero
cudaCheck(cudaMemset(out, 0, sizeof(float)));
global_norm_squared_kernel<<<grid_size, block_size>>>(out, values, count);
cudaCheck(cudaMemsetAsync(out, 0, sizeof(float), stream));
global_norm_squared_kernel<<<grid_size, block_size, 0, stream>>>(out, values, count);
cudaCheck(cudaGetLastError());
}
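Switching the zero-initialization from cudaMemset to cudaMemsetAsync on the caller's stream can matter for correctness, not just speed: plain cudaMemset is issued on the legacy default stream, and if the caller's stream is created as non-blocking it does not synchronize with the default stream, so the reset and the accumulation kernel could otherwise end up unordered. A purely illustrative sketch of such a stream and call:

    cudaStream_t stream;
    cudaCheck(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
    global_norm_squared(out, values, count, stream);  // reset and kernel are queued back to back on stream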

25 changes: 13 additions & 12 deletions llmc/layernorm.cuh
@@ -356,28 +356,28 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with

void layernorm_forward(floatX* out, floatX* mean, floatX* rstd,
floatX* inp, const floatX* weight, const floatX* bias,
int B, int T, int C) {
int B, int T, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 512;
const int N = B * T;
const int grid_size = CEIL_DIV(N * WARP_SIZE, block_size);
layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
layernorm_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, mean, rstd, inp, weight, bias, N, C);
cudaCheck(cudaGetLastError());
}

void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
assert(N % block_size == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2);
residual_forward_kernel<<<grid_size, block_size, 0, stream>>>(out, inp1, inp2);
cudaCheck(cudaGetLastError());
}

void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
const floatX* inp1, const floatX* inp2,
const floatX* weight, const floatX* bias,
int N, int C) {
int N, int C, cudaStream_t stream) {
const int block_size = 256;
int block_y = block_size / WARP_SIZE;
const int grid_size = CEIL_DIV(N, block_y);
@@ -389,26 +389,27 @@ void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, flo
auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaGetLastError();
if(status == cudaSuccess) {
fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem>>>(residual, normed, mean, rstd, inp1, inp2,
weight, bias, N, C);
fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
mean, rstd, inp1, inp2,
weight, bias, N, C);
} else {
residual_forward(residual, inp1, inp2, N*C);
layernorm_forward(normed, mean, rstd, residual, weight, bias, N, 1, C);
residual_forward(residual, inp1, inp2, N*C, stream);
layernorm_forward(normed, mean, rstd, residual, weight, bias, N, 1, C, stream);
}
cudaCheck(cudaGetLastError());
}

void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
int B, int T, int C) {
int B, int T, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 512;
const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3
const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount;
size_t rounded_C = CEIL_DIV(C, (32 * x128::size)) * (32 * x128::size);
size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float);

cudaCheck(cudaMemset(scratch, 0, 1 * sizeof(float))); // only need to reset the flag to 0
layernorm_backward_kernel10<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0
layernorm_backward_kernel10<<<grid_size, block_size, shared_mem_size, stream>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
cudaCheck(cudaGetLastError());
}
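One practical consequence of pushing everything onto a single stream: a kernel that faults is typically reported only at some later synchronization point, so the cudaCheck(cudaGetLastError()) right after a launch catches launch-configuration mistakes but not runtime errors. When chasing such a failure it can help to temporarily force a sync after the suspect launcher, e.g.:

    layernorm_backward(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, stream);
    cudaCheck(cudaStreamSynchronize(stream));  // debugging only: surfaces asynchronous errors from the kernel above

The extra synchronization serializes host and device, so it should not be left in for normal runs.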