Skip to content

Commit

Permalink
Merge pull request #246 from MeetKai/refactor-for-jinja
Browse files Browse the repository at this point in the history
Refactor codebase to use jinja
  • Loading branch information
jeffreymeetkai authored Sep 17, 2024
2 parents 2ade257 + d6cc922 commit ac42b58
Show file tree
Hide file tree
Showing 18 changed files with 713 additions and 593 deletions.
2 changes: 1 addition & 1 deletion functionary/prompt_template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, List

from functionary.prompt_template.base_template import SYSTEM_MESSAGE, PromptTemplate
from functionary.prompt_template.base_template import PromptTemplate
from functionary.prompt_template.llama3_prompt_template import Llama3Template
from functionary.prompt_template.llama3_prompt_template_v3 import Llama3TemplateV3
from functionary.prompt_template.llama31_prompt_template import Llama31Template
Expand Down
102 changes: 36 additions & 66 deletions functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import jinja2

from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils
from functionary.schema import generate_schema_from_functions

SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."


class PromptTemplate:
_jinja_env = jinja2.Environment()
_jinja_env.policies["json.dumps_kwargs"] = {"sort_keys": False}
# Mapping from class --> instance to create singleton instance
_instances = {}

def __init__(self):
    """Compile this template's Jinja chat template once, at construction.

    The template source comes from ``get_chat_template_jinja()`` and is
    compiled with the class-level ``_jinja_env``; the compiled template is
    stored on ``self._jinja_template``.
    """
    compile_from_string = self._jinja_env.from_string
    self._jinja_template = compile_from_string(self.get_chat_template_jinja())

@abstractmethod
def get_start_of_function_call_token(self) -> str:
Expand Down Expand Up @@ -53,18 +57,6 @@ def get_additional_tokens(self) -> List[str]:
"""
raise NotImplementedError

@abstractmethod
def convert_message_to_prompt(self, message: Dict) -> str:
    """Render a single message as prompt text.

    This base implementation is only a placeholder and always raises.

    Args:
        message (Dict): One message as a dictionary in OpenAI format.

    Returns:
        str: The prompt text for ``message``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

@abstractmethod
def get_stop_tokens_for_generation(self) -> List[str]:
"""Function to get list of stop tokens in generation
Expand Down Expand Up @@ -98,47 +90,12 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di
"""
return messages

def inject_system_messages_based_on_tools(
    self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None
) -> List[Dict]:
    """Put the tool-schema and default system messages ahead of a conversation.

    Two system messages are placed in front of ``messages``:

    1. the text returned by ``generate_schema_from_functions`` for the
       functions collected from ``tools_or_functions``;
    2. ``PYTHON_RUN_SYS_MSG`` when a ``code_interpreter`` tool is present,
       otherwise ``SYSTEM_MESSAGE``.

    Args:
        messages (List[Dict]): Conversation messages. The list is copied
            (shallowly), so the caller's list is left untouched.
        tools_or_functions (Optional[List[Dict]], optional): Tools in the new
            format (``{"type": ..., "function": ...}``), ``code_interpreter``
            entries, or bare function dicts in the old format.
            Defaults to None.

    Returns:
        List[Dict]: A new list: the two system messages followed by the
        original messages.
    """
    conversation = messages.copy()  # work on a copy, never on the caller's list

    function_specs = []
    uses_code_interpreter = False
    for entry in () if tools_or_functions is None else tools_or_functions:
        # New format: tools: [{"type": xx, "function": xxx}]
        if "function" in entry and entry["function"] is not None:
            function_specs.append(entry["function"])
            continue
        if "type" in entry and entry["type"] == "code_interpreter":
            uses_code_interpreter = True
            continue
        # Anything else is treated as an old-format bare function definition.
        function_specs.append(entry)

    schema_message = {
        "role": "system",
        "content": generate_schema_from_functions(function_specs),
    }
    if uses_code_interpreter:
        behaviour_message = {"role": "system", "content": PYTHON_RUN_SYS_MSG}
    else:
        behaviour_message = {"role": "system", "content": SYSTEM_MESSAGE}

    return [schema_message, behaviour_message] + conversation

def get_prompt_from_messages(
self,
messages: List[Dict],
tools_or_functions: Optional[List[Dict]] = None,
bos_token: Optional[str] = "",
add_generation_prompt: bool = False,
) -> str:
"""This function is used to get the complete prompt for list of messages
Expand All @@ -149,14 +106,15 @@ def get_prompt_from_messages(
Returns:
str: the prompt for inference/training
"""
messages_clone = self.inject_system_messages_based_on_tools(
messages, tools_or_functions

prompt = self._jinja_template.render(
messages=messages,
tools=tools_or_functions,
bos_token=bos_token,
add_generation_prompt=add_generation_prompt,
)

full_text = ""
for message in messages_clone:
full_text += self.convert_message_to_prompt(message)
return full_text # Do not strip because llama3 uses: \n\n before content
return prompt

def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:
"""return a dictionary mapping from end_token --> token_id
Expand Down Expand Up @@ -360,11 +318,13 @@ def get_raw_response_from_assistant_message(
str: The mock raw response in str format
"""
# Form raw response from messages list
raw_response = self.convert_message_to_prompt(message)

# Remove null content
null_content = self.convert_message_to_prompt({"role": "assistant"})
raw_response = raw_response[len(null_content) :]
sys_msg = self.get_prompt_from_messages(
messages=[], tools_or_functions=[], add_generation_prompt=True
)
assistant_msg = self.get_prompt_from_messages(
messages=[message], tools_or_functions=[]
)
raw_response = assistant_msg[len(sys_msg) :]

# Remove stop tokens
for stop_token in self.get_stop_tokens_for_generation():
Expand All @@ -375,9 +335,19 @@ def get_raw_response_from_assistant_message(

return raw_response.rstrip()

def get_chat_template_jinja(self):
    """Return the chat template in Jinja format.

    The result is a single Jinja comment that records ``self.version``.
    """
    version_tag = f"version={self.version}"
    return "".join(("{# ", version_tag, " #}"))
def get_chat_template_jinja(self) -> str:
    """Assemble the Jinja chat template for ``self.version``.

    Reads the shared ``json_to_ts_schema.txt`` snippet and the
    version-specific ``<version>.txt`` template, then splices the snippet
    (followed by a newline) into the template immediately before its first
    ``{%`` tag. Everything ahead of that tag stays at the top of the result.

    Returns:
        str: The combined Jinja template source.

    Raises:
        OSError: If either template file cannot be opened.
        ValueError: If the version template contains no ``{%`` tag.
    """
    # NOTE(review): this directory is resolved against the current working
    # directory, so loading only works when the process is started from the
    # repository root -- confirm with callers before changing it.
    path_prefix = "./functionary/prompt_template/jinja_templates/"

    # Read with an explicit encoding so the rendered prompt does not depend on
    # the platform's locale (assumes the template files are stored as UTF-8).
    with open(f"{path_prefix}json_to_ts_schema.txt", "r", encoding="utf-8") as f:
        json_to_ts_schema = f.read()
    with open(f"{path_prefix}{self.version}.txt", "r", encoding="utf-8") as f:
        template = f.read()

    # Locate the first Jinja statement once; the schema snippet goes right
    # before it.
    first_tag = template.index("{%")
    return template[:first_tag] + json_to_ts_schema + "\n" + template[first_tag:]

def get_generation_prefix_for_tool_choice(self, tool_choice: Any):
if tool_choice == "auto" or tool_choice is None:
Expand Down
Loading

0 comments on commit ac42b58

Please sign in to comment.