Commit

Merge branch 'master' into swiglu

gordicaleksa committed Sep 13, 2024
2 parents 92e881a + bd457aa commit 3f7d0cc
Showing 17 changed files with 1,714 additions and 180 deletions.
12 changes: 9 additions & 3 deletions Makefile
@@ -11,9 +11,12 @@ REMOVE_FILES = rm -f
OUTPUT_FILE = -o $@
CUDA_OUTPUT_FILE = -o $@

# Default O3 CPU optimization level for NVCC (0 for fastest compile time)
FORCE_NVCC_O ?= 3

# NVCC flags
# -t=0 is short for --threads, 0 = number of CPUs on the machine
NVCC_FLAGS = -O3 -t=0 --use_fast_math -std=c++17
NVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O)
NVCC_LDFLAGS = -lcublas -lcublasLt
NVCC_INCLUDES =
NVCC_LDLIBS =
@@ -45,8 +48,10 @@ endif

ifneq ($(CI),true) # if not in CI, then use the GPU query
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
ifneq ($(call file_exists_in_path, __nvcc_device_query),)
GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query)
ifneq ($(call file_exists_in_path, nvidia-smi),)
# Get the compute capabilities of all GPUs
# Remove decimal points, sort numerically in ascending order, and select the first (lowest) value
GPU_COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g' | sort -n | head -n 1)
GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
endif
endif
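As an aside (not part of the commit), the selection logic in the comment above can be sketched in Python: strip the decimal point from each GPU's reported compute capability, then keep the numerically lowest value so the build targets every installed GPU.

```python
# Hedged sketch of the lowest-compute-capability selection (illustration only).
def lowest_compute_capability(nvidia_smi_output: str) -> str:
    # nvidia-smi --query-gpu=compute_cap --format=csv,noheader prints one "X.Y" per GPU
    caps = [line.strip().replace(".", "")
            for line in nvidia_smi_output.splitlines() if line.strip()]
    return min(caps, key=int)  # numeric comparison, keep the lowest

print(lowest_compute_capability("8.9\n8.0\n9.0\n"))  # -> "80"
```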
@@ -62,6 +67,7 @@ $(info ---------------------------------------------)

ifneq ($(OS), Windows_NT)
NVCC := $(shell which nvcc 2>/dev/null)
NVCC_LDFLAGS += -lnvidia-ml

# Function to test if the compiler accepts a given flag.
define check_and_add_flag
6 changes: 6 additions & 0 deletions README.md
@@ -213,8 +213,14 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p
- [llm.cpp](https://github.com/gevtushenko/llm.c) by @[gevtushenko](https://github.com/gevtushenko): a port of this project using the [CUDA C++ Core Libraries](https://github.com/NVIDIA/cccl)
- A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [CUDA MODE Discord Server](https://discord.gg/cudamode)

- C++/CUDA
- [llm.cpp](https://github.com/zhangpiu/llm.cpp/tree/master/llmcpp) by @[zhangpiu](https://github.com/zhangpiu): a port of this project using the [Eigen](https://gitlab.com/libeigen/eigen), supporting CPU/CUDA.

- WebGPU C++
- [gpu.cpp](https://github.com/AnswerDotAI/gpu.cpp) by @[austinvhuang](https://github.com/austinvhuang): a library for portable GPU compute in C++ using native WebGPU. Aims to be a general-purpose library, but also porting llm.c kernels to WGSL.

- C++
- [llm.cpp](https://github.com/GaoYusong/llm.cpp) by @[GaoYusong](https://github.com/GaoYusong): a port of this project featuring a C++ single-header [tinytorch.hpp](https://github.com/GaoYusong/llm.cpp/blob/main/tinytorch.hpp) library

- Go
- [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project
7 changes: 4 additions & 3 deletions dev/cuda/Makefile
@@ -9,15 +9,15 @@ ifeq ($(NVCC),)
endif

ifneq ($(CI),true) # if not in CI, then use the GPU query
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too
GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
endif
endif

# Compiler flags
ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
CFLAGS = -O3 --use_fast_math
ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
CFLAGS = -O3 --use_fast_math
else
CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]
endif
@@ -31,6 +31,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-

# Build all targets
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute

all: $(TARGETS)
all_ptx: $(TARGETS:%=%.ptx)
all_sass: $(TARGETS:%=%.sass)
68 changes: 25 additions & 43 deletions dev/cuda/permute.cu
@@ -42,10 +42,10 @@ The linear index in a flattened 1D array is calculated as:
linear_idx = i1 × ( dim2 × dim3 × dim4 ) + i2 × ( dim3 × dim4 ) + i3 × dim4 + i4
This linear index uniquely identifies the position of the element in the 1D array.
To permute the matrix, we need to rearrange the indices according to the new shape.
To permute the matrix, we need to rearrange the indices according to the new shape.
In this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).
The new dimension post permutation will be as follow:
The new dimension post permutation will be as follows:
dim1 becomes the new 3rd dimension.
dim2 becomes the new 4th dimension.
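As an aside (not part of the diff), the index arithmetic described in this comment can be sanity-checked against NumPy's `transpose`; the sketch below assumes the (dim1, dim2, dim3, dim4) → (dim4, dim3, dim1, dim2) permutation stated above.

```python
# Sanity-check sketch for the linear-index formula above (illustration only).
import numpy as np

dim1, dim2, dim3, dim4 = 2, 3, 4, 5
x = np.arange(dim1 * dim2 * dim3 * dim4, dtype=np.float32)

out = np.empty_like(x)
for i1 in range(dim1):
    for i2 in range(dim2):
        for i3 in range(dim3):
            for i4 in range(dim4):
                src = i1 * (dim2 * dim3 * dim4) + i2 * (dim3 * dim4) + i3 * dim4 + i4
                # destination layout is (dim4, dim3, dim1, dim2)
                dst = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2
                out[dst] = x[src]

ref = x.reshape(dim1, dim2, dim3, dim4).transpose(3, 2, 0, 1).reshape(-1)
assert np.array_equal(out, ref)  # same mapping the CPU/CUDA code below implements
```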
@@ -74,7 +74,9 @@ Similarly we can follow the above approach to permute matrices of any dimensions
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include <cmath>
#include <cmath>

#include "common.h"

// CPU function to permute a 4D matrix
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
@@ -95,9 +97,9 @@ void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int
}

// CUDA kernel to permute a 4D matrix
__global__ void permute_cuda(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
__global__ void permute_kernel(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;

// Ensure index is within bounds
if (idx < dim1 * dim2 * dim3 * dim4) {
// Calculate the 4D indices from the linear index
@@ -113,32 +115,6 @@ __global__ void permute_cuda(const float* matrix, float* out_matrix, int dim1, i
}
}

// Function to check if the CUDA permutation result matches the CPU result
bool verify_results(const float* permuted_matrix_cuda, const float* permuted_matrix_cpu, int totalElements) {
bool success = true;
for (int i = 0; i < totalElements; i++) {
// Allow a small tolerance for floating-point comparison
if (fabs(permuted_matrix_cuda[i] - permuted_matrix_cpu[i]) > 1e-5) {
success = false;
printf("Permute Operation Failed\n");
printf("CPU: %f\n", permuted_matrix_cpu[i]);
printf("CUDA: %f\n", permuted_matrix_cuda[i]);
break; // Exit early on the first failure
}
}
if (success) {
printf("Permute Operation Passed\n");
}
return success;
}

// Function to initialize the matrix with random values
void initialize_matrix(float* mat, int dim_1, int dim_2, int dim_3, int dim_4) {
for (int i = 0; i < dim_1 * dim_2 * dim_3 * dim_4; ++i) {
mat[i] = static_cast<float>(rand()) / RAND_MAX;
}
printf("Matrix Initialized\n");
}

int main() {
int dim_1 = 24;
@@ -154,12 +130,10 @@ int main() {
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

// Allocate host memory
float* matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
float* matrix = make_random_float(dim_1 * dim_2 * dim_3 * dim_4);
float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
float* permuted_matrix_cpu = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

// Initialize the matrix with random values
initialize_matrix(matrix, dim_1, dim_2, dim_3, dim_4);

// Allocate device memory
float *d_matrix, *d_permuted_matrix;
@@ -170,30 +144,38 @@
cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice);

// Perform permutation on CPU
permute_cpu(matrix, permuted_matrix_cpu, dim_1, dim_2, dim_3, dim_4);
clock_t start = clock();
permute_cpu(matrix, permuted_matrix, dim_1, dim_2, dim_3, dim_4);
clock_t end = clock();
double elapsed_time_cpu = (double)(end - start) / CLOCKS_PER_SEC;

// Define block and grid sizes
dim3 blockSize(256);
dim3 blockSize(256);
int totalThreads = dim_1 * dim_2 * dim_3 * dim_4;
int gridSize = (totalThreads + blockSize.x - 1) / blockSize.x; // Compute grid size

// Launch CUDA kernel to perform permutation
permute_cuda<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
permute_kernel<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
cudaDeviceSynchronize(); // Ensure kernel execution is complete

// Copy the result from device to host
cudaMemcpy(permuted_matrix, d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyDeviceToHost);

// Verify results
verify_results(permuted_matrix, permuted_matrix_cpu, dim_1 * dim_2 * dim_3 * dim_4);
printf("Checking correctness...\n");
validate_result(d_permuted_matrix, permuted_matrix, "permuted_matrix", dim_1 * dim_2 * dim_3 * dim_4, 1e-5f);

printf("All results match.\n\n");
// benchmark kernel
int repeat_times = 1000;
float elapsed_time = benchmark_kernel(repeat_times, permute_kernel,
d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4
);
printf("time gpu %.4f ms\n", elapsed_time);
printf("time cpu %.4f ms\n", elapsed_time_cpu);

// Free allocated memory
free(matrix);
free(permuted_matrix);
free(permuted_matrix_cpu);
cudaFree(d_matrix);
cudaFree(d_permuted_matrix);

return 0;
}

2 changes: 2 additions & 0 deletions dev/data/README.md
@@ -6,3 +6,5 @@ The idea is that each dataset has a .py file here in the root of `dev/data`, and
- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it

And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.

Note: we support "gpt-2" and "llama" (llama 3 in particular) models and the above scripts will tokenize gpt-2 by default.
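As an illustration of that note (a sketch, not code from the repo; the file name and token ids below are made up), a dataset script selects the output format through `write_datafile`'s `model_desc` argument, defined in `data_common.py` below:

```python
# Hypothetical usage sketch; write_datafile is defined in dev/data/data_common.py.
from data_common import write_datafile

toks = [50256, 15496, 995]  # illustrative gpt-2 token ids, starting with <|endoftext|>
write_datafile("example_val.bin", toks)  # gpt-2 header, uint16 tokens (the default)
# write_datafile("example_val.bin", llama_toks, "llama-3")  # llama-3 header, uint32 tokens
```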
40 changes: 25 additions & 15 deletions dev/data/data_common.py
@@ -23,28 +23,38 @@ def download_file(url: str, fname: str, chunk_size=1024):
bar.update(size)


def write_datafile(filename, toks):
HEADERS_INFO = {
"gpt-2": {
"magic": 20240520,
"version": 1,
"token_dtype": np.uint16,
},
"llama-3": {
"magic": 20240801,
"version": 7,
"token_dtype": np.uint32,
},
}

def write_datafile(filename, toks, model_desc="gpt-2"):
"""
Saves token data as a .bin file, for reading in C.
- First comes a header with 256 int32s
- The tokens follow, each as a uint16
- The tokens follow, each as uint16 (gpt-2) or uint32 (llama)
"""
assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
assert model_desc in ["gpt-2", "llama-3"], f"unknown model descriptor {model_desc}"
info = HEADERS_INFO[model_desc]
# construct the header
header = np.zeros(256, dtype=np.int32)
header[0] = 20240520 # magic
header[1] = 1 # version
header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
# construct the tokens numpy array, if not already
if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
# validate that no token exceeds a uint16
maxtok = 2**16
assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
toks_np = np.array(toks, dtype=np.uint16)
else:
toks_np = toks
header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values
header[0] = info["magic"]
header[1] = info["version"]
header[2] = len(toks) # number of tokens after the 256*4 bytes of header
# construct the data (numpy array of tokens)
toks_np = np.array(toks, dtype=info["token_dtype"])
# write to file
print(f"writing {len(toks):,} tokens to {filename}")
num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize)
print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format")
with open(filename, "wb") as f:
f.write(header.tobytes())
f.write(toks_np.tobytes())
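For reference, a matching reader might look like this (a hedged sketch under the header constants above, not code from the repo):

```python
# Sketch: read back a .bin file written by write_datafile (illustration only).
import numpy as np

MAGIC_TO_DTYPE = {20240520: np.uint16,   # gpt-2
                  20240801: np.uint32}   # llama-3

def read_datafile(filename):
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)  # 256 int32 header
        magic, version, num_tokens = int(header[0]), int(header[1]), int(header[2])
        dtype = MAGIC_TO_DTYPE[magic]
        tokens = np.frombuffer(f.read(num_tokens * np.dtype(dtype).itemsize), dtype=dtype)
    assert len(tokens) == num_tokens
    return tokens
```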
52 changes: 41 additions & 11 deletions dev/data/fineweb.py
@@ -22,18 +22,22 @@
import os
import argparse
import multiprocessing as mp

import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm
import argparse

from transformers import AutoTokenizer


from data_common import write_datafile
# ------------------------------------------

parser = argparse.ArgumentParser(description="FineWeb and Edu-FineWeb dataset preprocessing")
parser.add_argument("-t", "--type", type=str, default="classic", help="Fineweb type, edu|classic")
parser.add_argument("-v", "--version", type=str, default="10B", help="Fineweb data sample size, 10B|100B")
parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", help="Model descriptor, gpt-2|llama-3")
parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each data shard in the output .bin files, in tokens")
args = parser.parse_args()

@@ -60,26 +64,52 @@
fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
name = "edu_fineweb"

# init the tokenizer
enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>'] # end of text token
def tokenize(doc):
def tokenize_llama(doc):
# tokenizes a single document and returns a numpy array of uint32 tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
tokens = [eot] # the special <|endoftext|> token delimits all documents
tokens.extend(encode(doc["text"]))
tokens_np = np.array(tokens)
assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
tokens_np_uint = tokens_np.astype(np.uint32)
return tokens_np_uint

def tokenize_gpt2(doc):
# tokenizes a single document and returns a numpy array of uint16 tokens
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
tokens = [eot] # the special <|endoftext|> token delimits all documents
tokens.extend(enc.encode_ordinary(doc["text"]))
tokens.extend(encode(doc["text"]))
tokens_np = np.array(tokens)
assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
tokens_np_uint16 = tokens_np.astype(np.uint16)
return tokens_np_uint16
tokens_np_uint = tokens_np.astype(np.uint16)
return tokens_np_uint
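A small usage sketch for the function above (not part of the commit); it only shows the expected dtype and the leading <|endoftext|> delimiter:

```python
# Usage sketch for tokenize_gpt2 (illustration only).
doc = {"text": "Hello world"}
toks = tokenize_gpt2(doc)
print(toks.dtype, toks[0])  # uint16 50256  (<|endoftext|> id in the gpt-2 vocabulary)
```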

token_dtype = {
"gpt-2": np.uint16,
"llama-3": np.uint32
}[args.model_desc]

# tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
with mp.Pool(nprocs) as pool:
shard_index = 0
# preallocate buffer to hold current shard
all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
all_tokens_np = np.empty((args.shard_size,), dtype=token_dtype)
token_count = 0
progress_bar = None

tokenize = lambda x: None
if args.model_desc == "gpt-2":
tokenize = tokenize_gpt2
elif args.model_desc == "llama-3":
tokenize = tokenize_llama
else:
raise ValueError(f"unknown model {args.model_desc}")

for tokens in pool.imap(tokenize, fw, chunksize=16):

# is there enough space in the current shard for the new tokens?
@@ -99,7 +129,7 @@ def tokenize(doc):
remainder = args.shard_size - token_count
progress_bar.update(remainder)
all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
write_datafile(filename, all_tokens_np)
write_datafile(filename, all_tokens_np.tolist(), args.model_desc)
shard_index += 1
progress_bar = None
# populate the next shard with the leftovers of the current doc
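The shard-filling pattern in this loop (fill a preallocated buffer, flush a full shard, carry the leftover tokens forward) can be sketched in isolation like this (toy sizes, no multiprocessing; printing stands in for `write_datafile`):

```python
# Toy sketch of the sharding loop (illustration only, not part of the commit).
import numpy as np

shard_size = 8                     # toy value; the script defaults to 10**8 tokens
buf = np.empty((shard_size,), dtype=np.uint16)
count, shard_index = 0, 0

docs = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12]]   # pretend tokenized documents
for tokens in docs:
    tokens = np.array(tokens, dtype=np.uint16)
    if count + len(tokens) < shard_size:               # fits in the current shard
        buf[count:count + len(tokens)] = tokens
        count += len(tokens)
    else:                                               # split the document across shards
        remainder = shard_size - count
        buf[count:] = tokens[:remainder]
        print(f"shard {shard_index}: {buf.tolist()}")   # stands in for write_datafile(...)
        shard_index += 1
        buf[:len(tokens) - remainder] = tokens[remainder:]
        count = len(tokens) - remainder

if count != 0:
    print(f"shard {shard_index} (partial): {buf[:count].tolist()}")
```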
@@ -110,4 +140,4 @@ def tokenize(doc):
if token_count != 0:
split = "val" if shard_index == 0 else "train"
filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin")
write_datafile(filename, all_tokens_np[:token_count])
write_datafile(filename, (all_tokens_np[:token_count]).tolist(), args.model_desc)