diff --git a/llm_exl2_dynamic_gen.py b/llm_exl2_dynamic_gen.py index d4135c0..fa21706 100644 --- a/llm_exl2_dynamic_gen.py +++ b/llm_exl2_dynamic_gen.py @@ -37,8 +37,7 @@ import uuid from blessed import Terminal import textwrap -from outlines.integrations.exllamav2 import RegexFilter, TextFilter, JSONFilter, ChoiceFilter -from util import format_prompt_llama3, format_prompt, format_prompt_tess, format_prompt_commandr +from outlines.integrations.exllamav2 import RegexFilter, JSONFilter, ChoiceFilter from util_merge import ExLlamaV2MergePassthrough def generate_unique_id(): @@ -628,31 +627,18 @@ def process_prompts(): worker = Thread(target=process_prompts) worker.start() +def get_messages(messages): + output = [] + for message in messages: + output.append({"role": message.role, "content": message.content}) + return output @app.post('/v1/chat/completions') async def mainchat(requestid: Request, request: ChatCompletionRequest): try: - prompt = '' - if repo_str == 'Phind-CodeLlama-34B-v2': - prompt = await format_prompt_code(request.messages) - elif repo_str == 'zephyr-7b-beta': - prompt = await format_prompt_zephyr(request.messages) - elif repo_str == 'llama3-70b-instruct' or 'llama3-70b-instruct-speculative': - prompt = await format_prompt_llama3(request.messages) - elif repo_str == 'Starling-LM-7B-alpha': - prompt = await format_prompt_starling(request.messages) - elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ': - prompt = await format_prompt_mixtral(request.messages) - elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'dbrx-instruct-exl2': - prompt = await format_prompt_yi(request.messages) - elif repo_str == 'Nous-Capybara-34B-GPTQ' or repo_str == 'goliath-120b-GPTQ' or repo_str == 'goliath-120b-exl2' or repo_str == 'goliath-120b-exl2-rpcal': - prompt = await format_prompt_nous(request.messages) - elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative': - prompt = await format_prompt_tess(request.messages) - elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative': - prompt = await format_prompt_commandr(request.messages) - else: - prompt = await format_prompt(request.messages) + hf_get_messages = get_messages(request.messages) + prompt = hf_tokenizer.apply_chat_template(hf_get_messages, tokenize=False, add_generation_prompt=True) + print(prompt) if request.partial_generation is not None: prompt += request.partial_generation