Merge pull request #25 from erik-dunteman/llama-quant
neuralmagic llama quant
charlesfrye authored Sep 18, 2024
2 parents 9e5ff35 + 471e301 commit c20db30
Showing 6 changed files with 45 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/app.py
@@ -4,11 +4,11 @@
"""
from pathlib import Path
import modal
import base64

from .xtts import XTTS
from .whisper import Whisper
from .llama import Llama
import base64

from .common import app

2 changes: 1 addition & 1 deletion src/common.py
@@ -1,3 +1,3 @@
from modal import App, Secret
from modal import App

app = App(name="quillman")
97 changes: 34 additions & 63 deletions src/llama.py
@@ -1,34 +1,18 @@
"""
Text generation service based on the Llama 3.1 8B Instruct model.
Text generation service based on the Llama 3.1 8B Instruct model by Meta.
The model is based on the [Meta-Llama](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) model, which is licensed under the Llama3.1 license.
Pulling the model weights from HuggingFace requires Meta org approval.
Follow these steps to obtain pull access:
- Go to https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
- Scroll through the "LLAMA 3.1 COMMUNITY LICENSE AGREEMENT"
- Fill out the form and submit
- Acquire a HuggingFace API token from https://huggingface.co/settings/tokens
- Set that token as a Modal secret with the name "my-huggingface-secret" at https://modal.com/secrets, using the variable name "HF_TOKEN"
Access is usually granted within an hour or two.
The model is an [FP8 quantized version by Neural Magic](https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8), which is licensed under the Llama3.1 license.
We use the [VLLM](https://github.com/vllm-project/vllm) library to run the model.
"""
import json
import time
from pathlib import Path
import os
import re

import modal

from .common import app

MODEL_DIR = "/model"

# Llama 3.1 requires an org approval, usually granted within a few hours
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_NAME = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
GPU_CONFIG = modal.gpu.A100(size="40GB", count=1)

llama_image = (
@@ -49,15 +33,14 @@
container_idle_timeout=60 * 10,
allow_concurrent_inputs=10,
image=llama_image,
secrets=[modal.Secret.from_name("my-huggingface-secret")],
)
class Llama:
@modal.build()
def download_model(self):
from huggingface_hub import snapshot_download, login
from huggingface_hub import snapshot_download
from transformers.utils import move_cache
login(os.environ["HF_TOKEN"])

print("Downloading model, this may take a few minutes...")
os.makedirs(MODEL_DIR, exist_ok=True)
snapshot_download(
MODEL_NAME,
@@ -70,6 +53,7 @@ def download_model(self):
def start_engine(self):
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from transformers import AutoTokenizer
t0 = time.time()

engine_args = AsyncEngineArgs(
@@ -81,6 +65,8 @@ def start_engine(self):
disable_log_requests=True,
)

self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# this can take some time!
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
print(f"VLLM engine started in {time.time() - t0:.2f}s")
@@ -91,63 +77,49 @@ def prewarm(self):
pass

@modal.method(is_generator=True)
async def generate(self, input, history=[]):
async def generate(self, prompt, history=[]):
from vllm import SamplingParams
from vllm.utils import random_uuid

stop_token = "<|END|>"
stop_tokens = [stop_token, "Human:"] # prevent model from generating a response to itself
system_prompt = f"You are a helpful AI assistant. Respond to the human to the best of your ability. Keep it brief.. When you have completed your response, end it with the token {stop_token}. For example: Human: What's the capital of France? Assistant: The capital of France is Paris.{stop_token}"

sampling_params = SamplingParams(
temperature=0.75,
max_tokens=128,
repetition_penalty=1.1,
stop=stop_tokens,
include_stop_str_in_output=False,
)

# prepend system message to history
history.insert(0, { "role": "system", "content": system_prompt })

# append current user input to history
history.append({ "role": "user", "content": input })
messages = [
{"role": "system", "content": f"You are a helpful AI assistant. Respond to the human to the best of your ability. Keep it brief."},
]

# Convert chat history to a single string
prompt = ""
for message in history:
role = message["role"]
content = message["content"]
if role == "system":
prompt += f"System: {content}\n"
elif role == "user":
prompt += f"Human: {content}\n"
elif role == "assistant":
prompt += f"Assistant: {content}\n"
for history_entry in history:
# history follows "role" + "content" format so can be used directly
messages.append(history_entry)

# Add the current user input
prompt += f"Human: {input}\n"
prompt += "Assistant: "
messages.append({"role": "user", "content": prompt})

request_id = random_uuid()
print(f"Request {request_id} generating with prompt:{prompt}")
result_stream = self.engine.generate(
prompt,
sampling_params,
request_id,
prompts = self.tokenizer.apply_chat_template(messages, tokenize=False)
sampling_params = SamplingParams(
temperature=0.75,
top_p=0.9,
max_tokens=256,
repetition_penalty=1.1
)

request_id = random_uuid()
print(f"Request {request_id} generating with prompt:\n{prompts}")
result_stream = self.engine.generate(prompts, sampling_params, request_id)
index = 0
buffer = ""
header_complete = False
async for output in result_stream:
if output.outputs[0].text and "\ufffd" == output.outputs[0].text[-1]:
# Skip incomplete unicode characters
continue

new_text = output.outputs[0].text[index:]
buffer += new_text
index = len(output.outputs[0].text)

# ignore leading <|start_header_id|>assistant<|end_header_id|>
if not header_complete:
if new_text == "<|end_header_id|>":
header_complete = True
continue

buffer += new_text

# Yield any complete words in the buffer
while buffer:
space_index = buffer.find(" ")
@@ -163,7 +135,6 @@ async def generate(self, input, history=[]):
if buffer.strip():
yield buffer.strip()


@app.local_entrypoint()
def main(prompt: str = "Who was Emperor Norton I, and what was his significance in San Francisco's history?"):
model = Llama()
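The rest of main() is truncated in this view. As a rough illustration of how a caller consumes the generator method, here is a hypothetical sketch, assuming Modal's .remote_gen() call for methods declared with is_generator=True; the names and arguments are illustrative, not taken from the commit:

# Hypothetical caller, not part of this commit: stream words from the Modal service.
llama = Llama()
for word in llama.generate.remote_gen("What's the capital of France?", []):
    print(word, end=" ", flush=True)
print()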
1 change: 1 addition & 0 deletions src/whisper.py
@@ -8,6 +8,7 @@

import time
import modal

from .common import app

cuda_version = "12.4.0" # should be no greater than host CUDA version
1 change: 0 additions & 1 deletion src/xtts.py
@@ -15,7 +15,6 @@
import io
import modal
import time

from .common import app

tts_image = (
14 changes: 8 additions & 6 deletions tests/e2e_test.py
@@ -12,8 +12,7 @@
import subprocess
import sys
import base64


import wave
import json
from pathlib import Path

@@ -29,12 +28,15 @@

# we're simulating a user speaking into a microphone
user_finish_time = None


def user_input_generator():
for wav in files:
wav = Path(__file__).parent / "test-audio" / wav
print(wav)

# sleep for duration of wav file to simulate user speaking
duration = wave.open(wav.as_posix(), "rb").getnframes() / wave.open(wav.as_posix(), "rb").getframerate()
print(f"Simulating user speaking for {duration:.2f} seconds")
time.sleep(duration)

with open(wav, "rb") as f:
yield f.read()

@@ -132,7 +134,7 @@ async def main():
except websockets.exceptions.WebSocketException:
pass

print(f"Done, output audios saved to /tmp/output_{i}.wav")
print("Done, output audios saved to /tmp/output_*.wav")


if __name__ == "__main__":
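One detail in the test change worth calling out: each clip's playback duration is read from the WAV header (frame count divided by frame rate) so the generator can sleep for a realistic speaking time before sending the audio. A small standalone sketch of that computation, using only the standard-library wave module (hypothetical helper, not in the commit):

import wave

def wav_duration_seconds(path: str) -> float:
    # Duration in seconds = number of frames / frames per second.
    with wave.open(path, "rb") as w:
        return w.getnframes() / float(w.getframerate())

The commit computes the same ratio inline; a helper like this would simply avoid opening the file twice.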
