
Merge pull request #740 from karpathy/gordicaleksa-fix_dataloader2
Gordicaleksa fix dataloader2
karpathy authored Aug 13, 2024
2 parents 1787210 + 755458d commit 4c84bc7
Showing 6 changed files with 164 additions and 80 deletions.
2 changes: 2 additions & 0 deletions dev/data/README.md
@@ -6,3 +6,5 @@ The idea is that each dataset has a .py file here in the root of `dev/data`, and
- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it

And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.

Note: we support the "gpt-2" and "llama" (llama 3 in particular) models, and the above scripts tokenize with the gpt-2 tokenizer by default.
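For example, `python tinyshakespeare.py -m llama-3` (equivalently `--model_desc llama-3`) tokenizes with the llama 3 tokenizer instead; `fineweb.py` and `tinystories.py` take the same flag.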
40 changes: 25 additions & 15 deletions dev/data/data_common.py
@@ -23,28 +23,38 @@ def download_file(url: str, fname: str, chunk_size=1024):
bar.update(size)


def write_datafile(filename, toks):
HEADERS_INFO = {
"gpt-2": {
"magic": 20240520,
"version": 1,
"token_dtype": np.uint16,
},
"llama-3": {
"magic": 20240801,
"version": 7,
"token_dtype": np.uint32,
},
}

def write_datafile(filename, toks, model_desc="gpt-2"):
"""
Saves token data as a .bin file, for reading in C.
- First comes a header with 256 int32s
- The tokens follow, each as a uint16
- The tokens follow, each as uint16 (gpt-2) or uint32 (llama)
"""
assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
assert model_desc in ["gpt-2", "llama-3"], f"unknown model descriptor {model_desc}"
info = HEADERS_INFO[model_desc]
# construct the header
header = np.zeros(256, dtype=np.int32)
header[0] = 20240520 # magic
header[1] = 1 # version
header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
# construct the tokens numpy array, if not already
if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
# validate that no token exceeds a uint16
maxtok = 2**16
assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
toks_np = np.array(toks, dtype=np.uint16)
else:
toks_np = toks
header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values
header[0] = info["magic"]
header[1] = info["version"]
header[2] = len(toks) # number of tokens after the 256*4 bytes of header
# construct the data (numpy array of tokens)
toks_np = np.array(toks, dtype=info["token_dtype"])
# write to file
print(f"writing {len(toks):,} tokens to {filename}")
num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize)
print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format")
with open(filename, "wb") as f:
f.write(header.tobytes())
f.write(toks_np.tobytes())
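For reference, a minimal sketch of reading one of these .bin files back in numpy, assuming only the header layout above (256 int32 values, with the magic number selecting the token dtype); the `read_datafile` name and the example path are illustrative:

import numpy as np

def read_datafile(filename):
    # the header is 256 int32 values: header[0] = magic, header[1] = version, header[2] = token count
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        magic, version, num_tokens = int(header[0]), int(header[1]), int(header[2])
        # the magic number identifies the format: 20240520 -> gpt-2 (uint16), 20240801 -> llama-3 (uint32)
        token_dtype = {20240520: np.uint16, 20240801: np.uint32}[magic]
        tokens = np.frombuffer(f.read(), dtype=token_dtype)
    assert len(tokens) == num_tokens, "token count in header does not match file contents"
    return tokens

# e.g. tokens = read_datafile("tinyshakespeare/tiny_shakespeare_val.bin")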
52 changes: 41 additions & 11 deletions dev/data/fineweb.py
@@ -22,18 +22,22 @@
import os
import argparse
import multiprocessing as mp

import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm
import argparse

from transformers import AutoTokenizer


from data_common import write_datafile
# ------------------------------------------

parser = argparse.ArgumentParser(description="FineWeb and Edu-FineWeb dataset preprocessing")
parser.add_argument("-t", "--type", type=str, default="classic", help="Fineweb type, edu|classic")
parser.add_argument("-v", "--version", type=str, default="10B", help="Fineweb data sample size, 10B|100B")
parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", help="Model descriptor, gpt-2|llama-3")
parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each data shard in the output .bin files, in tokens")
args = parser.parse_args()

@@ -60,26 +64,52 @@
fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
name = "edu_fineweb"

# init the tokenizer
enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>'] # end of text token
def tokenize(doc):
def tokenize_llama(doc):
# tokenizes a single document and returns a numpy array of uint32 tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
tokens = [eot] # the special <|endoftext|> token delimits all documents
tokens.extend(encode(doc["text"]))
tokens_np = np.array(tokens)
assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
tokens_np_uint = tokens_np.astype(np.uint32)
return tokens_np_uint

def tokenize_gpt2(doc):
# tokenizes a single document and returns a numpy array of uint16 tokens
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
tokens = [eot] # the special <|endoftext|> token delimits all documents
tokens.extend(enc.encode_ordinary(doc["text"]))
tokens.extend(encode(doc["text"]))
tokens_np = np.array(tokens)
assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
tokens_np_uint16 = tokens_np.astype(np.uint16)
return tokens_np_uint16
tokens_np_uint = tokens_np.astype(np.uint16)
return tokens_np_uint

token_dtype = {
"gpt-2": np.uint16,
"llama-3": np.uint32
}[args.model_desc]

# tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
with mp.Pool(nprocs) as pool:
shard_index = 0
# preallocate buffer to hold current shard
all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
all_tokens_np = np.empty((args.shard_size,), dtype=token_dtype)
token_count = 0
progress_bar = None

tokenize = lambda x: None
if args.model_desc == "gpt-2":
tokenize = tokenize_gpt2
elif args.model_desc == "llama-3":
tokenize = tokenize_llama
else:
raise ValueError(f"unknown model {args.model_desc}")

for tokens in pool.imap(tokenize, fw, chunksize=16):

# is there enough space in the current shard for the new tokens?
@@ -99,7 +129,7 @@ def tokenize(doc):
remainder = args.shard_size - token_count
progress_bar.update(remainder)
all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
write_datafile(filename, all_tokens_np)
write_datafile(filename, all_tokens_np.tolist(), args.model_desc)
shard_index += 1
progress_bar = None
# populate the next shard with the leftovers of the current doc
@@ -110,4 +140,4 @@ def tokenize(doc):
if token_count != 0:
split = "val" if shard_index == 0 else "train"
filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin")
write_datafile(filename, all_tokens_np[:token_count])
write_datafile(filename, (all_tokens_np[:token_count]).tolist(), args.model_desc)
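A simplified, single-process sketch of the shard-splitting logic above, assuming an iterable of per-document token arrays (the `write_shards`, `doc_token_arrays`, and `out_prefix` names are illustrative; the real script also handles the val/train naming, multiprocessing, and progress bars):

import numpy as np
from data_common import write_datafile

def write_shards(doc_token_arrays, shard_size, model_desc, out_prefix="shard"):
    # fill a fixed-size buffer with tokens; every time it fills up, flush it as one .bin shard
    dtype = {"gpt-2": np.uint16, "llama-3": np.uint32}[model_desc]
    buf = np.empty((shard_size,), dtype=dtype)
    count, shard_index = 0, 0
    for tokens in doc_token_arrays:
        while len(tokens) > 0:
            take = min(shard_size - count, len(tokens))  # how many tokens still fit in this shard
            buf[count:count + take] = tokens[:take]
            count += take
            tokens = tokens[take:]
            if count == shard_size:
                write_datafile(f"{out_prefix}_{shard_index:06d}.bin", buf.tolist(), model_desc)
                shard_index += 1
                count = 0
    if count != 0:  # flush the last, partially filled shard
        write_datafile(f"{out_prefix}_{shard_index:06d}.bin", buf[:count].tolist(), model_desc)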
57 changes: 41 additions & 16 deletions dev/data/tinyshakespeare.py
@@ -6,25 +6,32 @@
The output is written to a newly created tinyshakespeare/ folder.
The script prints:
Saved 32768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
Saved 305260 tokens to tinyshakespeare/tiny_shakespeare_train.bin
For GPT-2:
$ python dev/data/tinyshakespeare.py --model=gpt-2
writing 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (66,560 bytes) in the gpt-2 format
writing 305,260 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (611,544 bytes) in the gpt-2 format
For LLaMA 3:
$ python dev/data/tinyshakespeare.py --model=llama-3
writing 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (132,096 bytes) in the llama-3 format
writing 276,224 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (1,105,920 bytes) in the llama-3 format
And runs in a few seconds depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
"""

import argparse
import os

import tiktoken
import numpy as np
from transformers import AutoTokenizer

from data_common import download_file, write_datafile

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinyshakespeare")

enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})

def download():
"""Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -37,23 +44,41 @@ def download():
else:
print(f"{data_filename} already exists, skipping download...")

def tokenize():
def tokenize(model_desc):
if model_desc == "gpt-2":
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
elif model_desc == "llama-3":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
else:
raise ValueError(f"unknown model descriptor {model_desc}")
data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
text = open(data_filename, 'r').read()
# let's treat every person's statement in the dialog as a separate document
text = "<|endoftext|>" + text
text = text.replace('\n\n', '\n\n<|endoftext|>')
# encode the text
tokens = encode(text)
# let's treat every individual chunk of text as a separate "document"
sections = text.split("\n\n")
tokens = []
for i, s in enumerate(sections):
tokens.append(eot)
# there was a mild bug where I originally intended to remove \n\n, but instead just added
# the EOT right after each \n\n, so I'm keeping that behavior for backwards compatibility
# therefore we have to here add an extra \n\n at the end of each section, except the last
spad = s + "\n\n" if i != len(sections) - 1 else s
tokens.extend(encode(spad))
# let's take the first 32,768 tokens as the validation split (~10%)
val_tokens = tokens[:32768]
train_tokens = tokens[32768:]
# save to file
val_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_val.bin")
train_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_train.bin")
write_datafile(val_filename, val_tokens)
write_datafile(train_filename, train_tokens)
write_datafile(val_filename, val_tokens, model_desc)
write_datafile(train_filename, train_tokens, model_desc)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Tiny Shakespeare dataset preprocessing")
parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", choices=["gpt-2", "llama-3"], help="Model type, gpt-2|llama-3")
args = parser.parse_args()
download()
tokenize()
tokenize(args.model_desc)
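To make the backwards-compatibility comment above concrete, a small sketch of what the gpt-2 path produces for a toy string (the logic mirrors the loop above; the input text is made up):

import tiktoken

enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>']  # 50256 for gpt-2

text = "First speech.\n\nSecond speech.\n\nThird speech."
sections = text.split("\n\n")
tokens = []
for i, s in enumerate(sections):
    tokens.append(eot)
    # re-append the "\n\n" that split() removed (except after the last section), so the
    # encoded stream still places the EOT token right after each "\n\n", as before
    spad = s + "\n\n" if i != len(sections) - 1 else s
    tokens.extend(enc.encode_ordinary(spad))
# tokens is now: [50256, <"First speech.\n\n">, 50256, <"Second speech.\n\n">, 50256, <"Third speech.">]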
60 changes: 37 additions & 23 deletions dev/data/tinystories.py
@@ -1,38 +1,45 @@
"""
Downloads and tokenizes the TinyStories dataset.
- The download is from HuggingFace datasets.
- The tokenization is GPT-2 tokenizer with tiktoken
- The tokenization is using either GPT-2 or LLaMA 3 tokenizer.
The output is written to a newly created tinystories/ folder.
The script prints:
For GPT-2:
Number of shards: 50
Tokenizing val split...
Saved 19043638 tokens to tinystories/TinyStories_val.bin
writing 19,043,638 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
Saved 925653391 tokens to tinystories/TinyStories_train.bin
writing 925,653,391 tokens to tinystories/TinyStories_train.bin
And runs in 1-2 minutes two depending on your internet
For LLaMA 3:
Number of shards: 50
Tokenizing val split...
writing 18,660,516 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
writing 907,021,844 tokens to tinystories/TinyStories_train.bin
And runs in a few minutes depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
"""

import argparse
import os
import glob
import json
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

import tiktoken
import numpy as np
from transformers import AutoTokenizer

from data_common import download_file, write_datafile

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinystories")

enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)

def download():
"""Downloads the TinyStories dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -63,10 +70,20 @@ def download():
# data = json.load(f)
# print(f"Example story:\n{data[0]}")

def process_shard(shard_index, shard_filename):
def process_shard(shard_index, shard_filename, model_desc):
if model_desc == "gpt-2":
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
elif model_desc == "llama-3":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
else:
raise ValueError(f"unknown model descriptor {model_desc}")

with open(shard_filename, "r") as f:
data = json.load(f)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
rng = random.Random(1337 + shard_index)
rng.shuffle(data)
all_tokens = []
Expand All @@ -78,7 +95,7 @@ def process_shard(shard_index, shard_filename):
all_tokens.extend(tokens)
return all_tokens

def tokenize():
def tokenize(model_desc):
# shard 0 will be the val split, rest is train
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
@@ -89,20 +106,17 @@ def tokenize():
print(f"Tokenizing {split_name} split...")
all_tokens = []
with ProcessPoolExecutor() as executor:
futures = [executor.submit(process_shard, shard_index, shard_filename)
futures = [executor.submit(process_shard, shard_index, shard_filename, model_desc)
for shard_index, shard_filename in enumerate(split_shards)]
for future in as_completed(futures):
all_tokens.extend(future.result())

split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
write_datafile(split_filename, all_tokens)
write_datafile(split_filename, all_tokens, model_desc)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Tiny Stories dataset preprocessing")
parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", choices=["gpt-2", "llama-3"], help="Model type, gpt-2|llama-3")
args = parser.parse_args()
download()
tokenize()

# Prints:
# Tokenizing val split...
# Saved 19043638 tokens to data/TinyStories_val.bin
# Tokenizing train split...
# Saved 925653391 tokens to data/TinyStories_train.bin
tokenize(args.model_desc)
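Putting the pieces together for a single shard, a minimal sketch (the shard path is hypothetical, and it assumes this runs inside dev/data/tinystories.py where process_shard is defined; the real script fans the shards out with ProcessPoolExecutor):

from data_common import write_datafile

# tokenize one downloaded TinyStories shard with the gpt-2 tokenizer and write it out
val_tokens = process_shard(0, "tinystories/TinyStories_all_data/data00.json", "gpt-2")
write_datafile("tinystories/TinyStories_val.bin", val_tokens, "gpt-2")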
