Fix llama 3 data loader #736

Merged · 10 commits · merged into karpathy:master on Aug 13, 2024

Conversation

gordicaleksa (Contributor) commented Aug 10, 2024:

Add LLaMA 3 tokenization support for all our datasets:

  1. tiny shakespeare
  2. tiny stories
  3. fineweb
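
A minimal sketch of the dual-tokenizer setup this is after (the helper name and the llama-side eot handling are my assumptions, not the exact diff):

import tiktoken
from transformers import AutoTokenizer

def get_encoder(model):
    if model == "gpt-2":
        enc = tiktoken.get_encoding("gpt2")
        eot = enc._special_tokens['<|endoftext|>']  # end of text token
        encode = enc.encode_ordinary
    elif model == "llama":
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
        # llama 3 names its delimiter <|end_of_text|>; fetch it as an explicit id
        eot = tokenizer.convert_tokens_to_ids("<|end_of_text|>")
        encode = lambda text: tokenizer(text, add_special_tokens=False).input_ids
    else:
        raise ValueError(f"unknown model {model}")
    return encode, eot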

eot = enc._special_tokens['<|endoftext|>'] # end of text token
tokens = [eot] # the special <|endoftext|> token delimits all documents
elif model == "llama":
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

karpathy (Owner):

should we not use tiktoken in the exact same way as the official meta release shows? I like that it's more explicit. AutoTokenizer is a black box
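
For reference, roughly what the official release's tokenizer.py does (the tokenizer.model path is a placeholder and the special-token list is truncated here):

import tiktoken
from tiktoken.load import load_tiktoken_bpe

mergeable_ranks = load_tiktoken_bpe("Meta-Llama-3.1-8B/tokenizer.model")  # placeholder path
special_tokens = {
    "<|begin_of_text|>": 128000,
    "<|end_of_text|>": 128001,
    # ... the remaining reserved special tokens
}
enc = tiktoken.Encoding(
    name="llama3",
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens=special_tokens,
)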

gordicaleksa (Contributor, Author) replied Aug 12, 2024:

AutoTokenizer is a black box, I agree, but I believe we can be confident that, at least for an architecture as popular as LLaMA 3, HuggingFace is battle-tested!

I prefer that over downloading a tokenizer file and passing in a path.

What do you think? (Also: we already depend on HuggingFace either way.)

 # tokenizes a single document and returns a numpy array of uint16 tokens
 tokens = [eot] # the special <|endoftext|> token delimits all documents
-tokens.extend(enc.encode_ordinary(doc["text"]))
+text = doc["text"]

karpathy (Owner):

this order was intentional: the delimiter should be prepended to docs, so you can run inference starting from just that single token.

gordicaleksa (Contributor, Author):

I didn't change it, if you look above? We still have eot followed by the encoded doc (if I understood you correctly).

tokens_np_uint16 = tokens_np.astype(np.uint16)
return tokens_np_uint16

if model == "gpt-2":

karpathy (Owner):

refactor to delete the if and the copy-pasted code; use a ternary operator to set the upper_bound

gordicaleksa (Contributor, Author):

and a ternary in the assert statements as well?
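
Something like this sketch, presumably (a hypothetical helper, for illustration):

import numpy as np

def to_np_tokens(tokens, model):
    # one code path; bounds and dtype picked by ternaries
    upper_bound = 2**16 if model == "gpt-2" else 2**32
    dtype = np.uint16 if model == "gpt-2" else np.uint32
    tokens_np = np.array(tokens)
    assert (0 <= tokens_np).all() and (tokens_np < upper_bound).all(), \
        f"token dictionary too large for {np.dtype(dtype).name}"
    return tokens_np.astype(dtype)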

@@ -99,7 +120,7 @@ def tokenize(doc):
 remainder = args.shard_size - token_count
 progress_bar.update(remainder)
 all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
-write_datafile(filename, all_tokens_np)
+write_datafile(filename, list(all_tokens_np), args.model)

karpathy (Owner):

why convert to a list?

gordicaleksa (Contributor, Author) replied Aug 12, 2024:

I simplified write_datafile so that it doesn't have to handle both numpy arrays & lists; it's cleaner, I think (?)

karpathy (Owner):

There could be many tokens, so creating a Python list could be very wasteful

karpathy (Owner):

>>> a = np.random.randn(10)
>>> a
array([-1.39200423,  0.91909499,  0.49247546,  0.73578011, -0.46485352,
        0.06844696,  1.21521025,  0.18951044, -0.33376094,  1.03115886])
>>> list(a)
[-1.3920042324598616, 0.9190949922347375, 0.49247545796208686, 0.7357801064341112, -0.4648535191489631, 0.06844695804812885, 1.2152102515229188, 0.18951044050354424, -0.33376094056177236, 1.0311588596558752]
>>> z = list(a)
>>> z[0]
-1.3920042324598616
>>> type(z[0])
<class 'numpy.float64'>

karpathy (Owner):

vs tolist

>>> a = np.random.randn(10)
>>> a
array([-0.28416783,  3.61778557,  0.45557321,  0.6585392 , -0.54974637,
       -0.50662981,  0.36080734,  0.76378507, -1.60443242,  0.41719901])
>>> a.tolist()
[-0.2841678282966355, 3.6177855666263548, 0.45557321422210056, 0.6585391952854299, -0.5497463693208792, -0.5066298099246493, 0.36080734397795633, 0.7637850737170351, -1.6044324246329673, 0.4171990143035489]
>>> z = a.tolist()
>>> z[0]
-0.2841678282966355
>>> type(z[0])
<class 'float'>
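
So the list round-trip can be dropped entirely by keeping the numpy array; a sketch of write_datafile along those lines (the header layout is simplified, not the repo's exact format):

import numpy as np

def write_datafile(filename, toks_np, model):
    # toks_np stays a numpy array end to end; no Python list is materialized
    expected = np.uint16 if model == "gpt-2" else np.uint32
    assert toks_np.dtype == expected, f"expected {np.dtype(expected).name} tokens"
    with open(filename, "wb") as f:
        header = np.zeros(256, dtype=np.int32)  # simplified header
        header[2] = len(toks_np)                # token count
        f.write(header.tobytes())
        f.write(toks_np.tobytes())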

 enc = tiktoken.get_encoding("gpt2")
 eot = enc._special_tokens['<|endoftext|>'] # end of text token
-def tokenize(doc):
+def tokenize(doc, model):

karpathy (Owner):

instead of taking model and having a big if inside the def, we can have two defs for the two options and dispatch accordingly

gordicaleksa (Contributor, Author):

Sure; at this point I don't have a strong preference, as we only support 2 models.
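
A sketch of that dispatch (assuming enc, encode, and eot are set up per model as earlier; the range asserts are omitted):

def tokenize_gpt2(doc):
    tokens = [eot]  # the delimiter stays prepended
    tokens.extend(enc.encode_ordinary(doc["text"]))
    return np.array(tokens).astype(np.uint16)

def tokenize_llama(doc):
    tokens = [eot]
    tokens.extend(encode(doc["text"]))
    return np.array(tokens).astype(np.uint32)

tokenize = tokenize_gpt2 if args.model == "gpt-2" else tokenize_llama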

elif model == "llama":
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    def encode(x):
        return tokenizer(x).input_ids

karpathy (Owner):

pretty sure this now creates a bug for this code, because <|endoftext|> (below) doesn't tokenize properly
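
Roughly: tokenizer(x) splits a literal "<|endoftext|>" in the text into ordinary sub-tokens, and HF also prepends <|begin_of_text|> by default. A hedged sketch of a fix (the choice of llama delimiter is an assumption):

def encode(x):
    # suppress HF's automatic <|begin_of_text|> so no stray specials sneak in
    return tokenizer(x, add_special_tokens=False).input_ids

# prepend the delimiter as an explicit id rather than as text
eot = tokenizer.convert_tokens_to_ids("<|end_of_text|>")  # llama 3's analogue of <|endoftext|>
tokens = [eot] + encode(doc["text"])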

@@ -80,6 +96,15 @@ def tokenize(doc):
all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
token_count = 0
progress_bar = None

karpathy (Owner):

doesn't def tokenize break because:

all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)

i.e. the init is using uint16?
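
A sketch of the corresponding fix (dtype chosen per model):

# gpt-2 ids fit in uint16; llama 3's 128,256-entry vocab needs uint32
dtype = np.uint16 if args.model == "gpt-2" else np.uint32
all_tokens_np = np.empty((args.shard_size,), dtype=dtype)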

karpathy merged commit 9740a65 into karpathy:master on Aug 13, 2024. 13 checks passed.