Fix llama 3 (#1671)
* fix llama 3

* fix black
minhthuc2502 authored Apr 24, 2024
1 parent 3ba798e commit 042aac7
Showing 4 changed files with 139 additions and 1 deletion.
17 changes: 17 additions & 0 deletions examples/llama3/README.md
@@ -0,0 +1,17 @@
# Chat with Llama 3

In addition to the Llama 2 example, this example shows how to implement an interactive chat session with Llama 3. The script updates the prompt template to the format used by Llama 3.
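
For illustration, the prompt assembled by the script for a single user turn looks roughly like this (a sketch based on the `build_prompt()` function in `chat.py` below; the system prompt and user message are made up for illustration):

```
# Conceptual layout of the prompt built by build_prompt() in chat.py
# (hypothetical messages, whitespace as produced by the script):
prompt_text = (
    "<|begin_of_text|>"
    "<|start_header_id|> system <|end_header_id|> You are a helpful assistant. <|eot_id|>"
    "<|start_header_id|> user <|end_header_id|> Hello! <|eot_id|>"
    "<|start_header_id|> assistant <|end_header_id|>"
)
```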

## Installation
You can follow the README in the llama2 example to set up the environment.
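
The conversion step described in the llama2 README typically looks something like this for Llama 3 (a sketch; `llama-3-8b-chat-ct2` is just an example output directory and the quantization type is up to you):

```
ct2-transformers-converter --model meta-llama/Meta-Llama-3-8B-Instruct \
    --quantization int8_float16 --output_dir llama-3-8b-chat-ct2
```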

## Start a chat session
```
python3 chat.py llama-3-8b-chat-ct2/
```

You can also set a system prompt on the command line:

```
python3 chat.py llama-3-8b-chat-ct2/ "System prompt..."
```
116 changes: 116 additions & 0 deletions examples/llama3/chat.py
@@ -0,0 +1,116 @@
import os
import sys

import ctranslate2
from transformers import AutoTokenizer


def main():
    model_dir = sys.argv[1]
    system_prompt = sys.argv[2] if len(sys.argv) > 2 else None

    print("Loading the model...")
    generator = ctranslate2.Generator(model_dir, device="cuda")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    context_length = 4096
    max_generation_length = 512
    max_prompt_length = context_length - max_generation_length

    dialog = []

    if system_prompt:
        dialog.append({"role": "system", "content": system_prompt})

    while True:
        print("")

        user_prompt = input("You: ")

        dialog.append({"role": "user", "content": user_prompt})

        while True:
            prompt_tokens = build_prompt(tokenizer, dialog)
            if len(prompt_tokens) <= max_prompt_length:
                break
            # Remove old conversations to reduce the prompt size.
            if system_prompt:
                dialog = [dialog[0]] + dialog[3:]
            else:
                dialog = dialog[2:]

        step_results = generator.generate_tokens(
            prompt_tokens,
            max_length=max_generation_length,
            sampling_temperature=0.6,
            sampling_topk=20,
            sampling_topp=1,
        )

        print("")
        print("Llama3: ", end="", flush=True)

        text_output = ""

        for word in generate_words(tokenizer, step_results):
            print(word, end="", flush=True)
            text_output += word

        print("")

        dialog.append({"role": "assistant", "content": text_output.strip()})


def generate_words(tokenizer, step_results):
    # Buffer subword tokens and decode them into whole words; in the Llama
    # BPE vocabulary a leading "Ġ" marks the start of a new word.
    tokens_buffer = []

    for step_result in step_results:
        is_new_word = step_result.token.startswith("Ġ")

        if is_new_word and tokens_buffer:
            word = tokenizer.decode(tokens_buffer)
            if word:
                yield word
            tokens_buffer = []

        tokens_buffer.append(step_result.token_id)

    if tokens_buffer:
        word = tokenizer.decode(tokens_buffer)
        if word:
            yield word


B_ID, E_ID, E_INST = "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"


def build_prompt(tokenizer, dialog):
    begin_pos = 0
    if dialog[0]["role"] == "system":
        begin_pos = 1
    assert all([msg["role"] == "user" for msg in dialog[begin_pos::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[begin_pos + 1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )

    dialog_tokens = sum([
        tokenizer.tokenize(
            f"{B_ID} {(item['role'])} {E_ID} {(item['content']).strip()} {E_INST}"
        )
        for item in dialog
    ], [])
    dialog_tokens = ["<|begin_of_text|>"] + dialog_tokens + tokenizer.tokenize(
        f"{B_ID} assistant {E_ID}"
    )

    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"

    return dialog_tokens


if __name__ == "__main__":
    main()
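
The word grouping done by `generate_words()` above relies on word-initial tokens in the Llama BPE vocabulary being spelled with a leading `Ġ`. A minimal, self-contained sketch of the same idea (the real script buffers token ids and decodes them with the tokenizer; this toy version just joins token strings):

```
tokens = ["Hello", "Ġthere", ",", "Ġfriend"]

words, buffer = [], []
for token in tokens:
    # A "Ġ"-prefixed token starts a new word, so flush the buffered pieces.
    if token.startswith("Ġ") and buffer:
        words.append("".join(buffer).replace("Ġ", " "))
        buffer = []
    buffer.append(token)
words.append("".join(buffer).replace("Ġ", " "))

print(words)  # ['Hello', ' there,', ' friend']
```
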
3 changes: 3 additions & 0 deletions examples/llama3/requirements.txt
@@ -0,0 +1,3 @@
ctranslate2>=4.3.0
transformers[torch]==4.40.*
accelerate
4 changes: 3 additions & 1 deletion python/ctranslate2/converters/transformers.py
@@ -1436,7 +1436,9 @@ def set_vocabulary(self, spec, tokens):
     def set_config(self, config, model, tokenizer):
         config.bos_token = tokenizer.bos_token
         config.eos_token = tokenizer.eos_token
-        config.unk_token = tokenizer.unk_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
         config.layer_norm_epsilon = model.config.rms_norm_eps

     def set_layer_norm(self, spec, layer_norm):
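
The guarded assignment above is presumably needed because the Llama 3 tokenizer does not define an unknown token, so `tokenizer.unk_token` is `None` and the converter now falls back to an empty string. A quick way to check that assumption (a sketch; requires access to the gated meta-llama repository on the Hugging Face Hub):

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(tokenizer.unk_token)  # None for Llama 3, hence the fallback to ""
```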
