Showing 4 changed files with 139 additions and 1 deletion.
@@ -0,0 +1,17 @@
# Chat with Llama 3

In addition to the Llama 2 example, this is an example of how to implement an interactive chat session with Llama 3. The script uses the updated prompt template required by Llama 3.
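For reference, the prompt that `chat.py` assembles for each turn roughly follows this layout (the messages shown are placeholders; see `build_prompt` in the script for the exact construction):

```
<|begin_of_text|><|start_header_id|> system <|end_header_id|> {system prompt} <|eot_id|><|start_header_id|> user <|end_header_id|> {user message} <|eot_id|><|start_header_id|> assistant <|end_header_id|>
```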
## Installation
You can follow the README in the llama2 example to set up the environment.
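The conversion step is analogous to the Llama 2 example. Assuming you convert the model with the CTranslate2 converter directly, the command would look roughly like this (the quantization mode is illustrative; the output directory matches the chat command below):

```
ct2-transformers-converter --model meta-llama/Meta-Llama-3-8B-Instruct \
    --quantization int8_float16 --output_dir llama-3-7b-chat-ct2
```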
## Start a chat session
```
python3 chat.py llama-3-7b-chat-ct2/
```

You can also set a system prompt on the command line:

```
python3 chat.py llama-3-7b-chat-ct2/ "System prompt..."
```
@@ -0,0 +1,116 @@
import os
import sys

import ctranslate2
from transformers import AutoTokenizer


def main():
    model_dir = sys.argv[1]
    system_prompt = sys.argv[2] if len(sys.argv) > 2 else None

    print("Loading the model...")
    generator = ctranslate2.Generator(model_dir, device="cuda")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    context_length = 4096
    max_generation_length = 512
    max_prompt_length = context_length - max_generation_length

    dialog = []

    if system_prompt:
        dialog.append({"role": "system", "content": system_prompt})

    while True:
        print("")

        user_prompt = input("You: ")

        dialog.append({"role": "user", "content": user_prompt})

        while True:
            prompt_tokens = build_prompt(tokenizer, dialog)
            if len(prompt_tokens) <= max_prompt_length:
                break
            # Remove old conversations to reduce the prompt size.
            if system_prompt:
                dialog = [dialog[0]] + dialog[3:]
            else:
                dialog = dialog[2:]

        step_results = generator.generate_tokens(
            prompt_tokens,
            max_length=max_generation_length,
            sampling_temperature=0.6,
            sampling_topk=20,
            sampling_topp=1,
        )

        print("")
        print("Llama3: ", end="", flush=True)

        text_output = ""

        for word in generate_words(tokenizer, step_results):
            print(word, end="", flush=True)
            text_output += word

        print("")

        dialog.append({"role": "assistant", "content": text_output.strip()})


def generate_words(tokenizer, step_results):
    # Buffer generated subword tokens and yield them as whole words so the
    # output can be streamed to the terminal.
    tokens_buffer = []

    for step_result in step_results:
        # "Ġ" marks the start of a new word in the byte-level BPE vocabulary.
        is_new_word = step_result.token.startswith("Ġ")

        if is_new_word and tokens_buffer:
            word = tokenizer.decode(tokens_buffer)
            if word:
                yield word
            tokens_buffer = []

        tokens_buffer.append(step_result.token_id)

    if tokens_buffer:
        word = tokenizer.decode(tokens_buffer)
        if word:
            yield word


# Special tokens of the Llama 3 chat template.
B_ID, E_ID, E_INST = "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"


def build_prompt(tokenizer, dialog):
    begin_pos = 0
    if dialog[0]["role"] == "system":
        begin_pos = 1
    assert all([msg["role"] == "user" for msg in dialog[begin_pos::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[begin_pos + 1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )

    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"

    # Wrap each message in the Llama 3 header/footer tokens, prepend the
    # beginning-of-text token, and end with an open assistant header so the
    # model generates the next reply.
    dialog_tokens = sum([
        tokenizer.tokenize(
            f"{B_ID} {(item['role'])} {E_ID} {(item['content']).strip()} {E_INST}"
        )
        for item in dialog
    ], [])
    dialog_tokens = ["<|begin_of_text|>"] + dialog_tokens + tokenizer.tokenize(
        f"{B_ID} assistant {E_ID}"
    )

    return dialog_tokens


if __name__ == "__main__":
    main()
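A note on the context handling above: when the assembled prompt exceeds `max_prompt_length`, the inner loop drops the oldest user/assistant exchange while keeping the system message at index 0. A minimal sketch of that trimming step, using placeholder messages (illustration only, not part of the commit):

```
# Hypothetical dialog used only to illustrate the trimming in main().
dialog = [
    {"role": "system", "content": "Be concise."},      # kept
    {"role": "user", "content": "first question"},     # dropped
    {"role": "assistant", "content": "first answer"},  # dropped
    {"role": "user", "content": "second question"},    # kept
]
dialog = [dialog[0]] + dialog[3:]
print([msg["role"] for msg in dialog])  # ['system', 'user']
```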
@@ -0,0 +1,3 @@
ctranslate2>=4.3.0
transformers[torch]==4.40.*
accelerate