Fix llama 3 data loader #736

Merged · 10 commits · Aug 13, 2024
2 changes: 2 additions & 0 deletions dev/data/README.md
@@ -6,3 +6,5 @@ The idea is that each dataset has a .py file here in the root of `dev/data`, and
- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it

And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.

Note: we support the "gpt-2" and "llama" (specifically llama 3) model families; the scripts above tokenize for gpt-2 by default. Pass `-m llama` (e.g. `python tinyshakespeare.py -m llama`) to tokenize for llama 3 instead.
33 changes: 21 additions & 12 deletions dev/data/data_common.py
@@ -23,27 +23,36 @@ def download_file(url: str, fname: str, chunk_size=1024):
bar.update(size)


def write_datafile(filename, toks):
HEADERS_INFO = {
"gpt-2": {
"magic": 20240520,
"version": 1,
},
"llama": {
"magic": 20240801,
"version": 7,
},
}

def write_datafile(filename, toks, model="gpt-2"):
"""
Saves token data as a .bin file, for reading in C.
- First comes a header with 256 int32s
- The tokens follow, each as a uint16
- The tokens follow, each as uint16 (gpt-2) or uint32 (llama)
"""
assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
assert model in ["gpt-2", "llama"], f"unknown model {model}"
# construct the header
header = np.zeros(256, dtype=np.int32)
header[0] = 20240520 # magic
header[1] = 1 # version
header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
# construct the tokens numpy array, if not already
if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
# validate that no token exceeds a uint16
maxtok = 2**16
assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
header[0] = HEADERS_INFO[model]["magic"]
header[1] = HEADERS_INFO[model]["version"]
header[2] = len(toks) # number of tokens after the 256*4 bytes of header
if model == "gpt-2":
toks_np = np.array(toks, dtype=np.uint16)
elif model == "llama":
toks_np = np.array(toks, dtype=np.uint32)
else:
toks_np = toks
# write to file
raise ValueError(f"unknown model {model}")
print(f"writing {len(toks):,} tokens to {filename}")
with open(filename, "wb") as f:
f.write(header.tobytes())
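For reference, a minimal sketch of how a consumer could read these .bin files back and dispatch the token width on the magic number (an illustrative helper, not part of this PR):

import numpy as np

# illustrative only: map the magic numbers above to the on-disk token dtype
MAGIC_TO_DTYPE = {
    20240520: np.uint16,  # gpt-2
    20240801: np.uint32,  # llama 3
}

def read_datafile(filename):
    with open(filename, "rb") as f:
        # the header is 256 int32s: [0]=magic, [1]=version, [2]=token count
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        tokens = np.frombuffer(f.read(), dtype=MAGIC_TO_DTYPE[int(header[0])])
    assert len(tokens) == int(header[2]), "token count mismatch"
    return tokens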
47 changes: 36 additions & 11 deletions dev/data/fineweb.py
@@ -22,18 +22,22 @@
import os
import argparse
import multiprocessing as mp

import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm

from transformers import AutoTokenizer


from data_common import write_datafile
# ------------------------------------------

parser = argparse.ArgumentParser(description="FineWeb and Edu-FineWeb dataset preprocessing")
parser.add_argument("-t", "--type", type=str, default="classic", help="Fineweb type, edu|classic")
parser.add_argument("-v", "--version", type=str, default="10B", help="Fineweb data sample size, 10B|100B")
parser.add_argument("-m", "--model", type=str, default="gpt-2", help="Model type, gpt-2|llama")
parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each data shard in the output .bin files, in tokens")
args = parser.parse_args()

@@ -60,17 +64,29 @@
fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
name = "edu_fineweb"

# init the tokenizer
enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>'] # end of text token
def tokenize(doc):
# tokenizes a single document and returns a numpy array of uint16 tokens
def tokenize_llama(doc):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
def encode(x):
return tokenizer(x).input_ids
tokens_np = np.array(encode(doc["text"]))

assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
tokens_np_uint = tokens_np.astype(np.uint32)
return tokens_np_uint

def tokenize_gpt2(doc):
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
tokens = [eot] # the special <|endoftext|> token delimits all documents
tokens.extend(enc.encode_ordinary(doc["text"]))

# tokenizes a single document and returns a numpy array of uint16 tokens
tokens.extend(encode(doc["text"]))
tokens_np = np.array(tokens)

assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
tokens_np_uint16 = tokens_np.astype(np.uint16)
return tokens_np_uint16
tokens_np_uint = tokens_np.astype(np.uint16)
return tokens_np_uint
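One caveat worth noting: tokenize_llama constructs its tokenizer inside the function, so every document handed to it by pool.imap reloads the tokenizer. A sketch of a common workaround, caching it once per worker process (the cached-global name here is illustrative, not part of the diff):

_llama_tokenizer = None  # per-process cache, populated lazily in each worker

def tokenize_llama_cached(doc):
    global _llama_tokenizer
    if _llama_tokenizer is None:  # first call in this worker process
        _llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    tokens_np = np.array(_llama_tokenizer(doc["text"]).input_ids)
    assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
    return tokens_np.astype(np.uint32)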

# tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
@@ -80,6 +96,15 @@ def tokenize(doc):
all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
token_count = 0
progress_bar = None

karpathy (Owner) commented:

doesn't def tokenize break because:

all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)

i.e. the init is using uint16
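A minimal sketch of the fix this comment asks for, assuming the shard buffer dtype is chosen from the model rather than hard-coded (not a quote of the merged code):

# llama 3 token ids exceed 65535, so storing them into a uint16 buffer would
# silently wrap; pick the buffer dtype to match the model's token width
token_dtype = np.uint16 if args.model == "gpt-2" else np.uint32
all_tokens_np = np.empty((args.shard_size,), dtype=token_dtype)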

tokenize = lambda x: None
if args.model == "gpt-2":
tokenize = tokenize_gpt2
elif args.model == "llama":
tokenize = tokenize_llama
else:
raise ValueError(f"unknown model {args.model}")

for tokens in pool.imap(tokenize, fw, chunksize=16):

# is there enough space in the current shard for the new tokens?
@@ -99,7 +124,7 @@ def tokenize(doc):
remainder = args.shard_size - token_count
progress_bar.update(remainder)
all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
write_datafile(filename, all_tokens_np)
write_datafile(filename, list(all_tokens_np), args.model)
karpathy (Owner) commented:

why convert to list

gordicaleksa (Contributor, Author) replied on Aug 12, 2024:

i simplified write_datafile so that it doesn't have to handle both numpy & lists, it's cleaner i think (?)

karpathy (Owner) replied:

There could be many tokens, so creating a Python list could be very wasteful

karpathy (Owner) added:
>>> a = np.random.randn(10)
>>> a
array([-1.39200423,  0.91909499,  0.49247546,  0.73578011, -0.46485352,
        0.06844696,  1.21521025,  0.18951044, -0.33376094,  1.03115886])
>>> list(a)
[-1.3920042324598616, 0.9190949922347375, 0.49247545796208686, 0.7357801064341112, -0.4648535191489631, 0.06844695804812885, 1.2152102515229188, 0.18951044050354424, -0.33376094056177236, 1.0311588596558752]
>>> z = list(a)
>>> z[0]
-1.3920042324598616
>>> type(z[0])
<class 'numpy.float64'>

karpathy (Owner) added:

vs tolist

>>> a = np.random.randn(10)
>>> a
array([-0.28416783,  3.61778557,  0.45557321,  0.6585392 , -0.54974637,
       -0.50662981,  0.36080734,  0.76378507, -1.60443242,  0.41719901])
>>> a.tolist()
[-0.2841678282966355, 3.6177855666263548, 0.45557321422210056, 0.6585391952854299, -0.5497463693208792, -0.5066298099246493, 0.36080734397795633, 0.7637850737170351, -1.6044324246329673, 0.4171990143035489]
>>> z = a.tolist()
>>> z[0]
-0.2841678282966355
>>> type(z[0])
<class 'float'>
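The point of the two snippets: list(a) keeps numpy scalar objects while a.tolist() converts to native Python types, and either way a large Python list is materialized. A sketch of the allocation-free alternative karpathy is hinting at (an assumption, not a quote of the merged resolution):

# pass the numpy array straight through; no intermediate Python list
write_datafile(filename, all_tokens_np, args.model)

# inside write_datafile, np.asarray is a no-op for an ndarray that already has
# the right dtype, and converts lists or mismatched dtypes in a single pass
toks_np = np.asarray(toks, dtype=np.uint16 if model == "gpt-2" else np.uint32)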

shard_index += 1
progress_bar = None
# populate the next shard with the leftovers of the current doc
@@ -110,4 +135,4 @@ def tokenize(doc):
if token_count != 0:
split = "val" if shard_index == 0 else "train"
filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin")
write_datafile(filename, all_tokens_np[:token_count])
write_datafile(filename, list(all_tokens_np[:token_count]), args.model)
40 changes: 29 additions & 11 deletions dev/data/tinyshakespeare.py
@@ -6,25 +6,30 @@
The output is written to a newly created tinyshakespeare/ folder.
The script prints:

Saved 32768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
Saved 305260 tokens to tinyshakespeare/tiny_shakespeare_train.bin
For GPT-2:
writing 32,768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
writing 305,260 tokens to tinyshakespeare/tiny_shakespeare_train.bin

For LLaMA 3:
writing 32,768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
writing 319,555 tokens to tinyshakespeare/tiny_shakespeare_train.bin

And runs in a few seconds depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
"""

import argparse
import os

import tiktoken
import numpy as np
from transformers import AutoTokenizer

from data_common import download_file, write_datafile

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinyshakespeare")

enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})

def download():
"""Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -37,7 +42,17 @@ def download():
else:
print(f"{data_filename} already exists, skipping download...")

def tokenize():
def tokenize(model):
if model == "gpt-2":
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})
elif model == "llama":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
def encode(x):
return tokenizer(x).input_ids
karpathy (Owner) commented:

pretty sure this now creates bug for this code because <|endoftext|> (below) doesn't tokenize properly
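A quick way to reproduce the issue (a sketch, assuming the standard transformers tokenizer API): the llama 3 tokenizer has no <|endoftext|> special token, so the literal string is split into ordinary subword pieces instead of mapping to a single delimiter id.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
ids = tokenizer("<|endoftext|>").input_ids
# expected: several ordinary subword ids rather than one special-token id, so
# documents joined with '<|endoftext|>' are not cleanly delimited
print(len(ids), ids)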

else:
raise ValueError(f"unknown model {model}")

data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
text = open(data_filename, 'r').read()
# let's treat every person's statement in the dialog as a separate document
@@ -51,9 +66,12 @@ def tokenize():
# save to file
val_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_val.bin")
train_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_train.bin")
write_datafile(val_filename, val_tokens)
write_datafile(train_filename, train_tokens)
write_datafile(val_filename, val_tokens, model)
write_datafile(train_filename, train_tokens, model)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Tiny Shakespeare dataset preprocessing")
parser.add_argument("-m", "--model", type=str, default="gpt-2", choices=["gpt-2", "llama"], help="Model type, gpt-2|llama")
args = parser.parse_args()
download()
tokenize()
tokenize(args.model)
64 changes: 40 additions & 24 deletions dev/data/tinystories.py
@@ -1,38 +1,45 @@
"""
Downloads and tokenizes the TinyStories dataset.
- The download is from HuggingFace datasets.
- The tokenization is GPT-2 tokenizer with tiktoken
- The tokenization is using either GPT-2 or LLaMA 3 tokenizer.

The output is written to a newly created tinystories/ folder.
The script prints:

For GPT-2:
Number of shards: 50
Tokenizing val split...
Saved 19043638 tokens to tinystories/TinyStories_val.bin
writing 19,043,638 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
Saved 925653391 tokens to tinystories/TinyStories_train.bin
writing 925,653,391 tokens to tinystories/TinyStories_train.bin

And runs in 1-2 minutes two depending on your internet
For LLaMA 3:
Number of shards: 50
Tokenizing val split...
writing 18,660,516 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
writing 907,021,844 tokens to tinystories/TinyStories_train.bin

And runs in a few minutes depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
"""

import argparse
import os
import glob
import json
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

import tiktoken
import numpy as np
from transformers import AutoTokenizer

from data_common import download_file, write_datafile

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinystories")

enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)

def download():
"""Downloads the TinyStories dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -63,22 +70,34 @@ def download():
# data = json.load(f)
# print(f"Example story:\n{data[0]}")

def process_shard(shard_index, shard_filename):
def process_shard(shard_index, shard_filename, model):
if model == "gpt-2":
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
elif model == "llama":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
def encode(x):
return tokenizer(x).input_ids
eot = None
else:
raise ValueError(f"unknown model {model}")

with open(shard_filename, "r") as f:
data = json.load(f)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
rng = random.Random(1337 + shard_index)
rng.shuffle(data)
all_tokens = []
for example in data:
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = encode(text)
all_tokens.append(eot)
if eot is not None:
all_tokens.append(eot)
all_tokens.extend(tokens)
return all_tokens
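Note that with eot = None the llama path appends no explicit delimiter; this relies on the HF llama 3 tokenizer prepending its <|begin_of_text|> id to every encoded document by default, which then acts as the document boundary. A quick check of that assumption (a sketch, not part of the diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
ids = tok("hello world").input_ids
# with the default post-processor, ids[0] is the <|begin_of_text|> id, so
# concatenated documents still begin with a special boundary token
print(ids[0] == tok.bos_token_id)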

def tokenize():
def tokenize(model):
# shard 0 will be the val split, rest is train
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
@@ -89,20 +108,17 @@ def tokenize():
print(f"Tokenizing {split_name} split...")
all_tokens = []
with ProcessPoolExecutor() as executor:
futures = [executor.submit(process_shard, shard_index, shard_filename)
futures = [executor.submit(process_shard, shard_index, shard_filename, model)
for shard_index, shard_filename in enumerate(split_shards)]
for future in as_completed(futures):
all_tokens.extend(future.result())

split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
write_datafile(split_filename, all_tokens)
write_datafile(split_filename, all_tokens, model)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Tiny Stories dataset preprocessing")
parser.add_argument("-m", "--model", type=str, default="gpt-2", choices=["gpt-2", "llama"], help="Model type, gpt-2|llama")
args = parser.parse_args()
download()
tokenize()

# Prints:
# Tokenizing val split...
# Saved 19043638 tokens to data/TinyStories_val.bin
# Tokenizing train split...
# Saved 925653391 tokens to data/TinyStories_train.bin
tokenize(args.model)