diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py
index 98d7b1f..dc45c08 100644
--- a/functionary/prompt_template/__init__.py
+++ b/functionary/prompt_template/__init__.py
@@ -1,7 +1,7 @@
import re
from typing import Any, List
-from functionary.prompt_template.base_template import SYSTEM_MESSAGE, PromptTemplate
+from functionary.prompt_template.base_template import PromptTemplate
from functionary.prompt_template.llama3_prompt_template import Llama3Template
from functionary.prompt_template.llama3_prompt_template_v3 import Llama3TemplateV3
from functionary.prompt_template.llama31_prompt_template import Llama31Template
diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py
index 210b62c..da2548e 100644
--- a/functionary/prompt_template/base_template.py
+++ b/functionary/prompt_template/base_template.py
@@ -5,17 +5,21 @@
from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+import jinja2
+
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils
from functionary.schema import generate_schema_from_functions
-SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
-PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."
-
class PromptTemplate:
+ _jinja_env = jinja2.Environment()
+ _jinja_env.policies["json.dumps_kwargs"] = {"sort_keys": False}
# Mapping from class --> instance to create singleton instance
_instances = {}
+
+ def __init__(self):
+ self._jinja_template = self._jinja_env.from_string(self.get_chat_template_jinja())
@abstractmethod
def get_start_of_function_call_token(self) -> str:
@@ -53,18 +57,6 @@ def get_additional_tokens(self) -> List[str]:
"""
raise NotImplementedError
- @abstractmethod
- def convert_message_to_prompt(self, message: Dict) -> str:
- """Return the prompt of this message
-
- Args:
- message (Dict): Dictionary of openAI format
-
- Returns:
- str: prompt of this message
- """
- raise NotImplementedError
-
@abstractmethod
def get_stop_tokens_for_generation(self) -> List[str]:
"""Function to get list of stop tokens in generation
@@ -98,47 +90,12 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di
"""
return messages
- def inject_system_messages_based_on_tools(
- self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None
- ) -> List[Dict]:
- """This will be used to add Default system message, code-interpreter system message if needed
-
- Args:
- messages (List[Dict]): List of messages
- tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None.
-
- Returns:
- List[Dict]: _description_
- """
- messages_clone = messages.copy() # To avoid modifying the original list
-
- functions = []
- is_code_interpreter = False
- if tools_or_functions is not None:
- for item in tools_or_functions:
- if (
- "function" in item and item["function"] is not None
- ): # new data format: tools: [{"type": xx, "function": xxx}]
- functions.append(item["function"])
- elif "type" in item and item["type"] == "code_interpreter":
- is_code_interpreter = True
- else:
- functions.append(item) # old format
-
- messages_clone.insert(
- 0, {"role": "system", "content": generate_schema_from_functions(functions)}
- )
- if is_code_interpreter:
- messages_clone.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})
- else:
- messages_clone.insert(1, {"role": "system", "content": SYSTEM_MESSAGE})
-
- return messages_clone
-
def get_prompt_from_messages(
self,
messages: List[Dict],
tools_or_functions: Optional[List[Dict]] = None,
+ bos_token: Optional[str] = "",
+ add_generation_prompt: bool = False,
) -> str:
"""This function is used to get the complete prompt for list of messages
@@ -149,14 +106,15 @@ def get_prompt_from_messages(
Returns:
str: the prompt for inference/training
"""
- messages_clone = self.inject_system_messages_based_on_tools(
- messages, tools_or_functions
+
+ prompt = self._jinja_template.render(
+ messages=messages,
+ tools=tools_or_functions,
+ bos_token=bos_token,
+ add_generation_prompt=add_generation_prompt,
)
- full_text = ""
- for message in messages_clone:
- full_text += self.convert_message_to_prompt(message)
- return full_text # Do not strip because llama3 uses: \n\n before content
+ return prompt
def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:
"""return a dictionary mapping from end_token --> token_id
@@ -360,11 +318,13 @@ def get_raw_response_from_assistant_message(
str: The mock raw response in str format
"""
# Form raw response from messages list
- raw_response = self.convert_message_to_prompt(message)
-
- # Remove null content
- null_content = self.convert_message_to_prompt({"role": "assistant"})
- raw_response = raw_response[len(null_content) :]
+ sys_msg = self.get_prompt_from_messages(
+ messages=[], tools_or_functions=[], add_generation_prompt=True
+ )
+ assistant_msg = self.get_prompt_from_messages(
+ messages=[message], tools_or_functions=[]
+ )
+ raw_response = assistant_msg[len(sys_msg) :]
# Remove stop tokens
for stop_token in self.get_stop_tokens_for_generation():
@@ -375,9 +335,19 @@ def get_raw_response_from_assistant_message(
return raw_response.rstrip()
- def get_chat_template_jinja(self):
- """Return chat_template in jinja format"""
- return "{# " + f"version={self.version}" + " #}"
+ def get_chat_template_jinja(self) -> str:
+ path_prefix = "./functionary/prompt_template/jinja_templates/"
+ with open(f"{path_prefix}json_to_ts_schema.txt", "r") as f:
+ json_to_ts_schema = f.read()
+ with open(f"{path_prefix}{self.version}.txt", "r") as f:
+ template = f.read()
+
+ return (
+ template[: template.index("{%")]
+ + json_to_ts_schema
+ + "\n"
+ + template[template.index("{%") :]
+ )
def get_generation_prefix_for_tool_choice(self, tool_choice: Any):
if tool_choice == "auto" or tool_choice is None:
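
Note on the refactor above: PromptTemplate now compiles its per-version jinja template once, through a class-level jinja2.Environment whose tojson policy stops sorting keys. A minimal sketch of that rendering path, with made-up template markup and assuming only jinja2 is installed:

# Minimal sketch of the new rendering path (illustrative markup, not a real
# functionary template). The policy line mirrors PromptTemplate._jinja_env.
import jinja2

env = jinja2.Environment()
env.policies["json.dumps_kwargs"] = {"sort_keys": False}  # keep tool-arg key order

template = env.from_string(
    "{{ bos_token }}{% for m in messages %}"
    "<{{ m['role'] }}>{{ m['content'] }}{% endfor %}"
    "{% if add_generation_prompt %}<assistant>{% endif %}"
)
print(template.render(
    messages=[{"role": "user", "content": "hi"}],
    tools=[],
    bos_token="<s>",
    add_generation_prompt=True,
))
# -> <s><user>hi<assistant>

Compiling once in __init__ and sharing the environment keeps per-request work to a single render() call, replacing the per-message convert_message_to_prompt loop removed above.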
diff --git a/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt b/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt
new file mode 100644
index 0000000..aab27db
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt
@@ -0,0 +1,261 @@
+{%- macro append_new_param_info(param_declaration, comment_info, examples_info, depth) -%}
+ {%- set offset = "" -%}
+ {%- if depth >= 1 -%}
+ {%- set offset = " " * depth -%}
+ {%- endif -%}
+ {%- if comment_info != "<|NONE|>" -%}
+ {{ "\n" + offset + comment_info }}
+ {%- if examples_info | length > 0 -%}
+ {# Append each example info #}
+ {%- for example in examples_info -%}
+ {{ "\n" + offset + "// " + example|string|replace("'", '"') }}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- endif -%}
+ {{ "\n" + offset + param_declaration }}
+{%- endmacro -%}
+
+{%- macro convert_data_type(param_type) -%}
+ {%- if param_type == "integer" or param_type == "float" -%}
+ {{ "number" }}
+ {%- else -%}
+ {{ param_type }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro get_param_type(param) -%}
+ {%- set param_type = "any" -%}
+
+ {%- if "type" in param -%}
+ {%- set raw_param_type = param["type"] -%}
+ {%- if raw_param_type is iterable and raw_param_type is not string -%}
+ {%- set param_type = raw_param_type | join(" | ") -%}
+ {%- else -%}
+ {%- set param_type = raw_param_type -%}
+ {%- endif -%}
+ {{ convert_data_type(param_type) }}
+ {%- elif "oneOf" in param -%}
+ {%- set one_of_types = param["oneOf"]|selectattr("type", "defined")|list -%}
+ {%- set one_of_types = one_of_types|map(attribute="type")|unique|list -%}
+ {{ convert_data_type(one_of_types | join(" | ")) }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro get_format_param(param) -%}
+ {%- if "format" in param -%}
+ {{ param["format"] }}
+ {%- elif "oneOf" in param -%}
+ {%- set formats = [] -%}
+ {%- for item in param["oneOf"] -%}
+ {%- if "format" in item -%}
+ {%- if item["format"] == param["oneOf"][-1]["format"] -%}
+ {{ item["format"] }}
+ {%- else -%}
+ {{ item["format"] + " or "}}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{ "<|NONE|>" }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro get_param_info(param) -%}
+ {%- set param_type = param.get("type", "any") -%}
+ {%- set format_param = get_format_param(param) -%}
+
+ {%- if "description" in param or "default" in param or format_param != "<|NONE|>" or param["maximum"] or param["minimum"] or param["maxLength"] or param["minLength"] -%}
+ {{ "//" }}
+ {%- if "description" in param -%}
+ {%- set desc = param["description"] -%}
+ {%- if not desc.endswith(".") -%}
+ {%- set desc = desc + "." -%}
+ {%- endif -%}
+ {{ " " + desc }}
+ {%- endif -%}
+
+ {%- if "default" in param -%}
+ {%- set default_value = param["default"] -%}
+ {%- if param_type == "string" -%}
+ {%- set default_value = '"' ~ default_value ~ '"' -%}
+ {%- endif -%}
+ {{ " Default=" ~ default_value ~ "." }}
+ {%- endif -%}
+
+ {%- set format_param = get_format_param(param) -%}
+ {%- if format_param != "<|NONE|>" -%}
+ {{ " Format=" ~ format_param }}
+ {%- endif -%}
+
+ {%- for field, field_name in [("maximum", "Maximum"), ("minimum", "Minimum"), ("maxLength", "Maximum length"), ("minLength", "Minimum length")] -%}
+ {%- if field in param -%}
+ {{ " " + field_name ~ "=" ~ param[field] }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{ "<|NONE|>"}}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro get_enum_option_str(enum_options) -%}
+ {%- for v in enum_options -%}
+ {%- if v is string -%}
+ {{ '"' + v + '"' }}
+ {%- else -%}
+ {{ v }}
+ {%- endif -%}
+ {%- if enum_options|length > 0 and v != enum_options[-1] -%}
+ {{ " | " }}
+ {%- endif -%}
+ {%- endfor -%}
+{%- endmacro -%}
+
+{%- macro get_array_typescript(param_name, param_dic, depth) -%}
+ {%- set offset = '' -%}
+ {%- if depth >= 1 -%}
+ {%- set offset = " " * depth -%}
+ {%- endif -%}
+ {%- set items_info = param_dic.get('items', {}) -%}
+
+ {%- if items_info|length == 0 -%}
+ {%- if param_name -%}
+ {{ "\n" + offset + param_name + ": []" }}
+ {%- else -%}
+ {{ "\n" + offset + "[]" }}
+ {%- endif -%}
+ {%- else -%}
+ {%- set array_type = get_param_type(items_info) -%}
+ {%- if array_type == 'object' -%}
+ {%- if param_name -%}
+ {{ "\n" + offset + param_name + ": {" }}
+ {%- else -%}
+ {{ "\n" + offset + "{" }}
+ {%- endif -%}
+ {{ get_parameter_typescript(items_info.get('properties', {}), items_info.get('required', []), depth + 1) -}}
+ {{- "\n" + offset + "}[]" }}
+ {%- elif array_type == 'array' -%}
+ {%- set item_info = get_array_typescript(None, items_info, depth + 1) -%}
+ {%- if not param_name -%}
+ {{ "\n" + item_info + "[]" }}
+ {%- else -%}
+ {{ "\n" + offset + param_name + ": " + item_info|trim + "[]" }}
+ {%- endif -%}
+ {%- else -%}
+ {%- if 'enum' in items_info -%}
+ {%- set item_type = get_enum_option_str(items_info['enum']) -%}
+ {%- if param_name is none -%}
+ {{ "(" + item_type + ")[]"}}
+ {%- else -%}
+ {{ "\n" + offset + param_name + ": (" + item_type + ")[]" }}
+ {%- endif -%}
+ {%- else -%}
+ {%- if param_name is none -%}
+ {{ "\n" + array_type + "[]" }}
+ {%- else -%}
+ {{ "\n" + offset + param_name + ": " + array_type + "[]," }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro get_parameter_typescript(properties, required_params, depth=0) -%}
+ {%- set res = "" -%}
+ {%- for param_name, param in properties.items() -%}
+ {%- if param is mapping -%}
+ {%- set comment_info = get_param_info(param) -%}
+ {# Param Examples #}
+ {%- set examples_info = [] -%}
+ {%- if "examples" in param -%}
+ {%- set examples_info = ["Example " + param_name + ":"] -%}
+ {%- set examples_info = examples_info + param["examples"] -%}
+ {%- endif -%}
+
+ {# Param Name declaration #}
+ {%- set param_declaration = param_name -%}
+ {%- if required_params is iterable and param_name not in required_params -%}
+ {%- set param_declaration = param_declaration + "?" -%}
+ {%- endif -%}
+
+ {%- set param_type = get_param_type(param) -%}
+
+ {# Handle indentation based on depth #}
+ {%- set offset = "" -%}
+ {%- if depth >= 1 -%}
+ {%- set offset = " " * depth -%}
+ {%- endif -%}
+
+ {%- if param_type == "object" -%}
+ {%- if comment_info != "<|NONE|>" -%}
+ {{ "\n" + offset + comment_info }}
+ {%- endif -%}
+ {%- if examples_info|length > 0 -%}
+ {%- for example in examples_info -%}
+ {{ "\n" + offset + "// " + example|string|replace("'", '"') }}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- set param_declaration = param_declaration + ": {" -%}
+ {{ "\n" + offset + param_declaration -}}
+ {{- get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1) -}}
+ {{- "\n" + offset + "}," }}
+ {%- elif param_type == "array" -%}
+ {%- set item_info = param.get("items", {}) -%}
+ {%- if "type" not in item_info -%}
+ {%- set param_declaration = param_declaration + ": []," -%}
+ {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }}
+ {%- else -%}
+ {%- if comment_info != "<|NONE|>" -%}
+ {{ "\n" + offset + comment_info }}
+ {%- endif -%}
+ {%- if examples_info|length > 0 -%}
+ {%- for example in examples_info -%}
+ {{ "\n" + offset + "// " + example|string|replace("'", '"') }}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- set array_declaration = get_array_typescript(param_declaration, param, depth) -%}
+ {%- if not array_declaration.endswith(",") -%}
+ {%- set array_declaration = array_declaration + "," -%}
+ {%- endif -%}
+ {{ array_declaration}}
+ {%- endif -%}
+ {%- else -%}
+ {%- if "enum" in param -%}
+ {%- set param_type = get_enum_option_str(param["enum"]) -%}
+ {%- endif -%}
+ {%- if "nullable" in param and param["nullable"] -%}
+ {%- set param_type = param_type + " | null" -%}
+ {%- endif -%}
+ {%- set param_declaration = param_declaration + ": " + param_type + "," -%}
+ {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endfor -%}
+{%- endmacro -%}
+
+{%- macro generate_schema_from_functions(functions, namespace='functions') -%}
+ {{ "// Supported function definitions that should be called when necessary.\n" -}}
+ {{- "namespace " + namespace + " {\n\n" -}}
+
+ {%- for function in functions -%}
+ {%- if function.get("function") -%}
+ {%- set function = function.get("function") -%}
+ {%- endif -%}
+
+ {%- set function_name = function.get("name") -%}
+ {%- if function_name -%}
+ {%- set description = function.get('description', '') -%}
+ {%- set parameters = function.get('parameters', {}) -%}
+ {{- "// " + description + "\n" -}}
+ {{- "type " + function_name -}}
+ {%- if parameters and parameters.get("properties") -%}
+ {{- " = (_: {" -}}
+ {%- set required_params = parameters.get("required", []) -%}
+ {{ get_parameter_typescript(parameters.get("properties"), required_params, 0) -}}
+ {{- "\n}) => any;\n\n" }}
+ {%- else -%}
+ {{ " = () => any;\n\n" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {{ "} // namespace " + namespace }}
+{%- endmacro -%}
\ No newline at end of file
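
The macro file above renders nothing on its own. A hedged standalone driver (assuming it is run from the repo root, with a sample tool of our own invention) appends a call to generate_schema_from_functions:

# Exercise the macro file directly; macros defined earlier in a template
# string are callable later in the same string.
import jinja2

env = jinja2.Environment()
with open("functionary/prompt_template/jinja_templates/json_to_ts_schema.txt") as f:
    macros = f.read()

template = env.from_string(macros + "{{ generate_schema_from_functions(functions) }}")
print(template.render(functions=[{
    "name": "get_weather",
    "description": "get the weather of a location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "where to get weather"}
        },
        "required": ["location"],
    },
}]))
# Output, roughly:
#   // Supported function definitions that should be called when necessary.
#   namespace functions {
#   // get the weather of a location
#   type get_weather = (_: {
#   // where to get weather.
#   location: string,
#   }) => any;
#   } // namespace functions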
diff --git a/functionary/prompt_template/jinja_templates/v2.llama3.txt b/functionary/prompt_template/jinja_templates/v2.llama3.txt
new file mode 100644
index 0000000..4796a9a
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/v2.llama3.txt
@@ -0,0 +1,28 @@
+{# version=v2.llama3 #}{%- if not tools -%}
+ {%- set tools = [] -%}
+{%- endif -%}
+{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}}
+{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%}
+ {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }}
+{%- else -%}
+ {{ "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary<|eot_id|>" }}
+{%- endif -%}
+{%- for message in messages -%}
+ {%- if message['role'] == 'user' or message['role'] == 'system' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- elif message['role'] == 'tool' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + 'name=' + message['name'] + '\n' + message['content'] + '<|eot_id|>' }}
+ {%- else -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
+ {%- if message['content'] -%}
+ {{ message['content'] }}
+ {%- endif -%}
+ {%- if 'tool_calls' in message and message['tool_calls'] -%}
+ {%- for tool_call in message['tool_calls'] -%}
+ {{ '<|reserved_special_token_249|>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ '<|eot_id|>' }}
+ {%- endif -%}
+{%- endfor -%}
+{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}
\ No newline at end of file
diff --git a/functionary/prompt_template/jinja_templates/v2.txt b/functionary/prompt_template/jinja_templates/v2.txt
new file mode 100644
index 0000000..ff7dc86
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/v2.txt
@@ -0,0 +1,27 @@
+{# version=v2 #}{%- if not tools -%}
+ {%- set tools = [] -%}
+{%- endif -%}
+{{ bos_token + '<|from|>system\n<|recipient|>all\n<|content|>' + generate_schema_from_functions(tools) -}}
+{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%}
+ {{ '\n<|from|>system\n<|recipient|>all\n<|content|>When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.' }}
+{%- else -%}
+ {{ "\n<|from|>system\n<|recipient|>all\n<|content|>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary" }}
+{%- endif -%}
+{%- for message in messages -%}
+ {%- if message['role'] == 'user' or message['role'] == 'system' -%}
+ {{ '\n<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] }}
+ {%- elif message['role'] == 'tool' -%}
+ {{ '\n<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] }}
+ {%- else -%}
+ {%- if message['content'] -%}
+ {{ "\n<|from|>" + message['role'] + "\n<|recipient|>all\n<|content|>" + message['content'] }}
+ {%- endif -%}
+ {%- if 'tool_calls' in message and message['tool_calls'] -%}
+ {%- for tool_call in message['tool_calls'] -%}
+ {{ '\n<|from|>' + message['role'] + '\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] }}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ "<|stop|>" }}
+ {%- endif -%}
+{%- endfor -%}
+{% if add_generation_prompt %}{{ '\n<|from|>assistant\n<|recipient|>' }}{% endif %}
\ No newline at end of file
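
For readers unfamiliar with the v2 wire format, an illustrative helper (ours, not repo code) showing the turn framing the template above emits:

# Each turn is framed as <|from|>{sender}\n<|recipient|>{recipient}\n<|content|>{content};
# assistant text goes to recipient "all", function calls to the function name.
def v2_turn(sender: str, recipient: str, content: str) -> str:
    return f"\n<|from|>{sender}\n<|recipient|>{recipient}\n<|content|>{content}"

print(v2_turn("user", "all", "what's the weather in Hanoi?"))
print(v2_turn("assistant", "get_weather", '{"location": "Hanoi"}') + "<|stop|>")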
diff --git a/functionary/prompt_template/jinja_templates/v3-llama3.1.txt b/functionary/prompt_template/jinja_templates/v3-llama3.1.txt
new file mode 100644
index 0000000..29d64a2
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/v3-llama3.1.txt
@@ -0,0 +1,58 @@
+{# version=v3-llama3.1 #}{%- if not tools is defined -%}
+ {%- set tools = none -%}
+{%- endif -%}
+
+{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%}
+{%- if has_code_interpreter -%}
+ {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%}
+{%- endif -%}
+
+{#- System message + builtin tools #}
+{{- bos_token + "<|start_header_id|>system<|end_header_id|>\n\n" }}
+{%- if has_code_interpreter %}
+ {{- "Environment: ipython\n\n" }}
+{%- else -%}
+ {{ "\n"}}
+{%- endif %}
+{{- "Cutting Knowledge Date: December 2023\n\n" }}
+{%- if tools %}
+ {{- "\nYou have access to the following functions:\n\n" }}
+ {%- for t in tools %}
+ {%- if "type" in t -%}
+ {{ "Use the function '"|safe + t["function"]["name"] + "' to '"|safe + t["function"]["description"] + "'\n"|safe + t["function"] | tojson() }}
+ {%- else -%}
+ {{ "Use the function '"|safe + t["name"] + "' to '"|safe + t["description"] + "'\n"|safe + t | tojson() }}
+ {%- endif -%}
+ {{- "\n\n" }}
+ {%- endfor %}
+ {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => `<function`\nparameters => a JSON dict with the function argument name as key and function argument value as value.\nend_tag => `</function>`\n\nHere is an example,\n<function=example_function_name>{"example_name": "example_value"}</function>\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with <function= and end with </function>\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}}
+{%- endif %}
+{{- "<|eot_id|>" -}}
+
+{%- for message in messages -%}
+ {%- if message['role'] == 'user' or message['role'] == 'system' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- elif message['role'] == 'tool' -%}
+ {{ '<|start_header_id|>ipython<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- else -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
+ {%- if message['content'] -%}
+ {{ message['content'] }}
+ {%- endif -%}
+ {%- if 'tool_calls' in message and message['tool_calls'] -%}
+ {%- for tool_call in message['tool_calls'] -%}
+ {%- if tool_call["function"]["name"] == "python" -%}
+ {{ '<|python_tag|>' + tool_call['function']['arguments'] }}
+ {%- else -%}
+ {{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }}
+ {%- endif -%}
+ {%- endfor -%}
+ {{ '<|eom_id|>' }}
+ {%- else -%}
+ {{ '<|eot_id|>' }}
+ {%- endif -%}
+ {%- endif -%}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+ {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+{%- endif -%}
\ No newline at end of file
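
The "| tojson" calls in this template are why base_template.py overrides the json.dumps_kwargs policy: jinja's default sorts keys alphabetically, which would rewrite the function JSON relative to what the model saw in training. A quick demonstration:

import jinja2

tool = {"name": "get_weather", "description": "get the weather", "parameters": {}}

default_env = jinja2.Environment()
print(default_env.from_string("{{ t | tojson }}").render(t=tool))
# {"description": "get the weather", "name": "get_weather", "parameters": {}}

env = jinja2.Environment()
env.policies["json.dumps_kwargs"] = {"sort_keys": False}
print(env.from_string("{{ t | tojson }}").render(t=tool))
# {"name": "get_weather", "description": "get the weather", "parameters": {}}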
diff --git a/functionary/prompt_template/jinja_templates/v3.llama3.txt b/functionary/prompt_template/jinja_templates/v3.llama3.txt
new file mode 100644
index 0000000..5a1d8a5
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/v3.llama3.txt
@@ -0,0 +1,26 @@
+{# version=v3.llama3 #}{%- if not tools -%}
+ {%- set tools = [] -%}
+{%- endif -%}
+{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}}
+{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%}
+ {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }}
+{%- endif -%}
+{%- for message in messages -%}
+ {%- if message['role'] == 'user' or message['role'] == 'system' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- elif message['role'] == 'tool' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- else -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
+ {%- if message['content'] -%}
+ {{ '>>>all\n' + message['content'] }}
+ {%- endif -%}
+ {%- if 'tool_calls' in message and message['tool_calls'] -%}
+ {%- for tool_call in message['tool_calls'] -%}
+ {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ '<|eot_id|>' }}
+ {%- endif -%}
+{%- endfor -%}
+{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %}
\ No newline at end of file
diff --git a/functionary/prompt_template/jinja_templates/v3.llava_llama.txt b/functionary/prompt_template/jinja_templates/v3.llava_llama.txt
new file mode 100644
index 0000000..c3abf5a
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/v3.llava_llama.txt
@@ -0,0 +1,40 @@
+{# version=v3.llava_llama #}{%- if not tools -%}
+ {%- set tools = [] -%}
+{%- endif -%}
+{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}}
+{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%}
+ {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }}
+{%- endif -%}
+{%- for message in messages -%}
+ {%- if message['role'] == 'system' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- elif message['role'] == 'user' -%}
+ {%- if message['content'] is string -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] }}
+ {%- else -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
+ {%- for content in message['content'] -%}
+ {%- if content['type'] == 'text' -%}
+ {{ content['text'] }}
+ {%- else -%}
+ {{ '<|reserved_special_token_250|>' }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ '<|eot_id|>' }}
+ {%- elif message['role'] == 'tool' -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
+ {%- else -%}
+ {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
+ {%- if message['content'] -%}
+ {{ '>>>all\n' + message['content'] }}
+ {%- endif -%}
+ {%- if 'tool_calls' in message and message['tool_calls'] -%}
+ {%- for tool_call in message['tool_calls'] -%}
+ {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ '<|eot_id|>' }}
+ {%- endif -%}
+{%- endfor -%}
+{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %}
\ No newline at end of file
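
The user-message branch above flattens mixed text/image content into a single string. A minimal re-implementation of just that flattening, for illustration:

IMAGE_TOKEN = "<|reserved_special_token_250|>"

def flatten_user_content(content) -> str:
    # String content passes through; list content keeps text parts and
    # collapses each image_url part to the reserved image placeholder.
    if isinstance(content, str):
        return content
    return "".join(
        part["text"] if part["type"] == "text" else IMAGE_TOKEN
        for part in content
    )

print(flatten_user_content([
    {"type": "image_url", "image_url": "hanoi.png"},
    {"type": "text", "text": "what's the weather like in Hanoi?"},
]))
# -> <|reserved_special_token_250|>what's the weather like in Hanoi?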
diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py
index cff00a2..c237ce1 100644
--- a/functionary/prompt_template/llama31_prompt_template.py
+++ b/functionary/prompt_template/llama31_prompt_template.py
@@ -8,69 +8,6 @@
from functionary.prompt_template.base_template import PromptTemplate
-def get_system_prompt_for_custom_tools(custom_tools: List) -> str:
- custom_tool_params = ""
- for t in custom_tools:
- custom_tool_params += get_instruction_string(t) + "\n"
- custom_tool_params += get_parameters_string(t) + "\n\n"
-
- content = f"""
-You have access to the following functions:
-
-{custom_tool_params}
-Think very carefully before calling functions.
-If a you choose to call a function ONLY reply in the following format:
-<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
-where
-
-start_tag => `<function`
-parameters => a JSON dict with the function argument name as key and function argument value as value.
-end_tag => `</function>`
-
-Here is an example,
-{{"example_name": "example_value"}}
-
-Reminder:
-- If looking for real time information use relevant functions before falling back to brave_search
-- Function calls MUST follow the specified format, start with <function= and end with </function>
-- Required parameters MUST be specified
-- Only call one function at a time
-- Put the entire function call reply on one line
-
-"""
- return content
-
-
-def get_instruction_string(custom_tool_definition) -> str:
- name, description = (
- custom_tool_definition["name"],
- custom_tool_definition["description"],
- )
- return f"Use the function '{name}' to '{description}'"
-
-
-def get_parameters_string(custom_tool_definition) -> str:
- return json.dumps(custom_tool_definition)
-
-
-def get_system_message_for_tools(tools: List[Dict], use_code_interpreter) -> List[Dict]:
- content = ""
- if use_code_interpreter:
- content += "Environment: ipython\n"
-
- current_date = datetime.datetime.now()
- formatted_date = current_date.strftime("%d %B %Y")
- date_str = f"""
-Cutting Knowledge Date: December 2023\n\n"""
- content += date_str
-
- if tools:
- custom_message = get_system_prompt_for_custom_tools(tools)
- content += custom_message
-
- return {"role": "system", "content": content}
-
-
def parse_function_call_from_text(function_call_text: str) -> Optional[Dict]:
index = function_call_text.find(">")
if index >= 0:
@@ -109,87 +46,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di
def get_stop_tokens_for_generation(self) -> List[str]:
return [self.eos_token, "<|end_of_text|>", self.eof_message]
- def inject_system_messages_based_on_tools(
- self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None
- ) -> List[Dict]:
- """This will be used to add Default system message, code-interpreter system message if needed
-
- Args:
- messages (List[Dict]): List of messages
- tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None.
-
- Returns:
- List[Dict]: _description_
- """
- messages_clone = messages.copy() # To avoid modifying the original list
-
- functions = []
- is_code_interpreter = False
- if tools_or_functions is not None:
- for item in tools_or_functions:
- if (
- "function" in item and item["function"] is not None
- ): # new data format: tools: [{"type": xx, "function": xxx}]
- functions.append(item["function"])
- elif "type" in item and item["type"] == "code_interpreter":
- is_code_interpreter = True
- else:
- functions.append(item) # old format
-
- tools_system_message = get_system_message_for_tools(
- functions, is_code_interpreter
- )
- messages_clone.insert(0, tools_system_message)
-
- return messages_clone
-
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- if role == "tool":
- role = "ipython"
- content = message.get("content", None)
-
- prompt_template = (
- f"{self.start_header}{role}{self.end_header}\n\n" + "{text}{eot_content}"
- )
- eot_content = self.eos_token
-
- if role in ["user", "system", "ipython"]:
- return prompt_template.format(text=content, eot_content=eot_content)
-
- assert role == "assistant", f"role must be assistant, but: {role}"
-
- # set content=none if content=""
- if type(content) is str and len(content) == 0:
- content = None
-
- tool_calls = message.get("tool_calls", [])
- if tool_calls is None:
- tool_calls = []
-
- if content is None and len(tool_calls) == 0: # inference time
- return f"{self.start_header}{role}{self.end_header}\n\n"
-
- total_content = content if content else ""
-
- # list of text representing function calls: {function_name}\n{arguments}
- tool_call_prompts = []
- for tool_call in tool_calls:
- arguments = tool_call["function"]["arguments"]
- assert isinstance(arguments, str)
- tool_name = tool_call["function"]["name"]
- if tool_name == "python":
- tool_prompt = f"<|python_tag|>{arguments}"
- else:
- tool_prompt = f"{arguments}"
- tool_call_prompts.append(tool_prompt)
-
- # join all function calls
- if tool_call_prompts:
- total_content += "".join(tool_call_prompts)
- eot_content = self.eof_message
- return prompt_template.format(text=total_content, eot_content=eot_content)
-
def parse_assistant_response(
self, llm_output: str, tool_choice: Any = None
) -> Dict:
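
The <function=...> call format above is what parse_function_call_from_text (kept in this file) consumes. A hedged sketch of such a parser, not the repo's exact implementation:

import re

def parse_function_call(text: str):
    match = re.match(r"<function=([^>]+)>(.*?)</function>", text.strip(), re.S)
    if match is None:
        return None
    # arguments stay a raw JSON string, matching the repo's convention
    return {"name": match.group(1), "arguments": match.group(2)}

print(parse_function_call('<function=get_weather>{"location": "Hanoi"}</function>'))
# -> {'name': 'get_weather', 'arguments': '{"location": "Hanoi"}'}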
diff --git a/functionary/prompt_template/llama3_prompt_template.py b/functionary/prompt_template/llama3_prompt_template.py
index 5042acb..ef882c8 100644
--- a/functionary/prompt_template/llama3_prompt_template.py
+++ b/functionary/prompt_template/llama3_prompt_template.py
@@ -8,22 +8,6 @@
from functionary.prompt_template.base_template import PromptTemplate
-def convert_to_llama3_messages(messages: List[Dict]) -> List[Dict]:
- result = []
- index = 0
- while index < len(messages):
- if messages[index]["role"] in ["user", "system"]:
- result.append(messages[index])
- index += 1
- else:
- if messages[index]["role"] == "assistant":
- tool_calls = messages[index].get("tool_calls", [])
- if len(tool_calls) == 0:
- result.append(messages[index])
- else:
- messages
-
-
class Llama3Template(PromptTemplate):
function_separator = "<|reserved_special_token_249|>"
version = "v2.llama3"
@@ -172,58 +156,6 @@ def parse_assistant_response(
tool_calls = None if len(tool_calls) == 0 else tool_calls
return {"role": "assistant", "content": text_content, "tool_calls": tool_calls}
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- content = message.get("content", None)
- if role == "tool":
- tool_name = message["name"]
- content = f"name={tool_name}\n{content}"
-
- prompt_template = (
- "<|start_header_id|>%s<|end_header_id|>\n\n{text}<|eot_id|>" % role
- )
-
- if role in ["user", "system", "tool"]:
- return prompt_template.format(text=content)
-
- assert role == "assistant"
- # set content=none if content=""
- if type(content) is str and len(content) == 0:
- content = None
-
- tool_calls = message.get("tool_calls", [])
- if tool_calls is None:
- tool_calls = []
-
- if content is None and len(tool_calls) == 0:
- return f"<|start_header_id|>{role}<|end_header_id|>\n\n"
-
- if content is not None and len(tool_calls) == 0: # text-only
- return prompt_template.format(text=content)
-
- tool_call_prompts = []
- for tool_call in tool_calls:
- arguments = tool_call["function"]["arguments"]
- tool_name = tool_call["function"]["name"]
- tool_prompt = f"{tool_name}\n{arguments}"
- tool_call_prompts.append(tool_prompt)
-
- if (
- content is None and len(tool_calls) > 0
- ): # function call only (<sep>fc1<sep>fc2)
- tool_call_content = self.function_separator + self.function_separator.join(
- tool_call_prompts
- )
- return prompt_template.format(text=tool_call_content)
-
- # Here is the case contains both text-response and tool_calls (content<sep>fc1<sep>fc2)
- total_content = (
- content
- + self.function_separator
- + self.function_separator.join(tool_call_prompts)
- )
- return prompt_template.format(text=total_content)
-
def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
"""Order the tool results by the order of tool call ids
@@ -235,18 +167,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di
"""
return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages)
- def update_state_for_function(self, current_state):
- """update the state when a function is going to be called
-
- Args:
- current_state (_type_): _description_
- """
- current_state["response_type"] = "function"
- current_state["skip_until_reach"] = "\n"
- current_state["current_text"] = ""
- current_state["func_index"] += 1
- current_state["call_id"] = prompt_utils.get_random_tool_call_id()
-
def initialize_fsm_gen_state(
self,
tool_choice: Union[str, Tool],
@@ -441,32 +361,6 @@ def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
return options
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user' or message['role'] == 'system' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + 'name=' + message['name'] + '\n' + message['content'] + '<|eot_id|>' }}
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
- {% if message['content'] is not none %}
- {{ message['content'] }}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '<|reserved_special_token_249|>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
-
def get_force_text_generation_prefix(self):
return ""
diff --git a/functionary/prompt_template/llama3_prompt_template_v3.py b/functionary/prompt_template/llama3_prompt_template_v3.py
index 91af993..a5793a7 100644
--- a/functionary/prompt_template/llama3_prompt_template_v3.py
+++ b/functionary/prompt_template/llama3_prompt_template_v3.py
@@ -3,18 +3,7 @@
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils
-from functionary.prompt_template.base_template import PYTHON_RUN_SYS_MSG, PromptTemplate
-from functionary.schema import generate_schema_from_functions
-
-SYSTEM_CONTENT = """You are capable of executing available function(s) if required.
-Only execute function(s) when absolutely necessary.
-Ask for the required input to:recipient==all
-Use JSON for function arguments.
-Respond in this format:
->>>${recipient}
-${content}
-Available functions:
-"""
+from functionary.prompt_template.base_template import PromptTemplate
class Llama3TemplateV3(PromptTemplate):
@@ -147,56 +136,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di
"""
return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages)
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- content = message.get("content", None)
-
- # comment this as currently the Llama-70b was trained using this
- # if role == "tool":
- # tool_name = message["name"]
- # content = f"name={tool_name}\n{content}"
-
- prompt_template = (
- f"{self.start_header}{role}{self.end_header}\n\n"
- + "{text}"
- + self.eos_token
- )
-
- if role in ["user", "system", "tool"]:
- return prompt_template.format(text=content)
-
- assert role == "assistant", f"role must be assistant, but: {role}"
-
- # set content=none if content=""
- if type(content) is str and len(content) == 0:
- content = None
-
- tool_calls = message.get("tool_calls", [])
- if tool_calls is None:
- tool_calls = []
-
- if content is None and len(tool_calls) == 0: # inference time
- return f"{self.start_header}{role}{self.end_header}\n\n{self.function_separator}"
-
- if content is not None: # text-only
- tool_calls = [
- {"function": {"name": "all", "arguments": content}}
- ] + tool_calls
-
- # list of text representing function calls: {function_name}\n{arguments}
- tool_call_prompts = []
- for tool_call in tool_calls:
- arguments = tool_call["function"]["arguments"]
- tool_name = tool_call["function"]["name"]
- tool_prompt = f"{tool_name}\n{arguments}"
- tool_call_prompts.append(tool_prompt)
-
- # join all function calls
- total_content = self.function_separator + self.function_separator.join(
- tool_call_prompts
- )
- return prompt_template.format(text=total_content)
-
def parse_assistant_response(
self, llm_output: str, tool_choice: Any = None
) -> Dict:
@@ -235,45 +174,6 @@ def parse_assistant_response(
return {"role": "assistant", "content": text_content, "tool_calls": tool_calls}
- def inject_system_messages_based_on_tools(
- self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None
- ) -> List[Dict]:
- """This will be used to add Default system message, code-interpreter system message if needed
-
- Args:
- messages (List[Dict]): List of messages
- tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None.
-
- Returns:
- List[Dict]: _description_
- """
- messages_clone = messages.copy() # To avoid modifying the original list
-
- functions = []
- is_code_interpreter = False
- if tools_or_functions is not None:
- for item in tools_or_functions:
- if (
- "function" in item and item["function"] is not None
- ): # new data format: tools: [{"type": xx, "function": xxx}]
- functions.append(item["function"])
- elif "type" in item and item["type"] == "code_interpreter":
- is_code_interpreter = True
- else:
- functions.append(item) # old format
-
- messages_clone.insert(
- 0,
- {
- "role": "system",
- "content": SYSTEM_CONTENT + generate_schema_from_functions(functions),
- },
- )
- if is_code_interpreter:
- messages_clone.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})
-
- return messages_clone
-
def get_force_text_generation_prefix(self):
return f"all\n"
@@ -478,29 +378,3 @@ def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
options = [self.fn_param_sep_token]
return options
-
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user' or message['role'] == 'system' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
- {% if message['content'] is not none %}
- {{ '>>>all\n' + message['content'] }}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
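
An illustrative split of the v3 ">>>" assistant format (ours, not repo code): each ">>>{recipient}\n{body}" chunk is free text when the recipient is "all" and a function call otherwise.

# Naive split; assumes ">>>" never occurs inside content.
def split_v3_chunks(raw: str):
    results = []
    for chunk in raw.split(">>>"):
        if not chunk:
            continue
        recipient, _, body = chunk.partition("\n")
        results.append((recipient, body))
    return results

raw = ('>>>all\nI will get the price of 2 cars and compare'
       '>>>get_car_price\n{"car_name": "Song"}')
print(split_v3_chunks(raw))
# -> [('all', 'I will get the price of 2 cars and compare'),
#     ('get_car_price', '{"car_name": "Song"}')]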
diff --git a/functionary/prompt_template/llava_prompt_template.py b/functionary/prompt_template/llava_prompt_template.py
index 1a3b64a..bc9f7e2 100644
--- a/functionary/prompt_template/llava_prompt_template.py
+++ b/functionary/prompt_template/llava_prompt_template.py
@@ -9,52 +9,3 @@ class LlavaLlama(Llama3TemplateV3):
version = "v3.llava_llama"
# This token will be replaced with image_token_id (-200) after we tokenize the text
image_token = "<|reserved_special_token_250|>"
-
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- content = message.get("content", None)
-
- # handle the case when user uploads images (content is a list)
- if role == "user" and type(content) is list:
- text_content = prompt_utils.stringify_content_with_images(
- message["content"], self.image_token
- )
- return f"{self.start_header}{role}{self.end_header}\n\n{text_content}{self.eos_token}"
- return super().convert_message_to_prompt(message)
-
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user'%}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
- {% if message['content'] is iterable and (message['content'] is not string and message['content'] is not mapping) %}
- {% for content_item in message['content'] %}
- {% if content_item['type'] == 'image_url' %}
- {{ '<|reserved_special_token_250|>' }}
- {% else %}
- {{ content_item['text'] }}
- {% endif %}
- {% endfor %}
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% endif %}
- {% elif message['role'] == 'tool' or message['role'] == 'system' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
- {% if message['content'] is not none %}
- {{ '>>>all\n' + message['content'] }}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
diff --git a/functionary/prompt_template/prompt_template_v2.py b/functionary/prompt_template/prompt_template_v2.py
index 66731cf..64c7b93 100644
--- a/functionary/prompt_template/prompt_template_v2.py
+++ b/functionary/prompt_template/prompt_template_v2.py
@@ -121,50 +121,6 @@ def get_additional_tokens(self) -> List[str]:
self.stop_token,
]
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- content = message.get("content", None)
-
- if role in [
- "system",
- "user",
- ]: # <|from|>system\n<|recipient|>all\n<|content|>xxx
- return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- if role == "tool": # <|from|>tool_name\n<|recipient|>all\n<|content|>xxx
- tool_name = message["name"]
- return f"{self.from_token}{tool_name}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- assert role == "assistant"
-
- # set content=none if content=""
- if type(content) is str and len(content) == 0:
- content = None
-
- tool_calls = message.get("tool_calls", [])
- if tool_calls is None:
- tool_calls = []
- if (
- len(tool_calls) == 0 and content is None
- ): # for inference: <|from|> assistant\n<|recipient|>
- return f"{self.from_token}{role}\n{self.recipient_token}"
-
- if len(tool_calls) == 0: # <|from|>assistant\n<|recipient|>all\n<|content|>xxx
- return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}{self.stop_token}\n"
-
- result = ""
- if content is not None: # both text-response and function_call
- result += f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- for tool in tool_calls:
- func_name = tool["function"]["name"]
- arguments = tool["function"]["arguments"]
- # <|from|>assistant\n<|recipient|>func_name\n<|content|>xxxx
- result += f"{self.from_token}{role}\n{self.recipient_token}{func_name}\n{self.content_token}{arguments}\n"
-
- result = result.strip() + f"{self.stop_token}\n"
- return result
-
def get_stop_tokens_for_generation(self) -> List[str]:
return [self.stop_token]
@@ -252,38 +208,6 @@ def get_recipient(self, current_text: str) -> str:
end_index = current_text.find(f"\n{self.content_token}")
return current_text[start_index:end_index].strip()
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user' or message['role'] == 'system' %}
- {{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}
- {% else %}
- {% set contain_content='no'%}
- {% if message['content'] is not none %}
- {{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}
- {% set contain_content='yes'%}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}
- {% if loop.index == 1 and contain_content == "no" %}
- {{ prompt }}
- {% else %}
- {{ '\n' + prompt}}
- {% endif %}
- {% endfor %}
- {% endif %}
- {{ '<|stop|>\n' }}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
-
def initialize_fsm_gen_state(
self,
tool_choice: Union[str, Tool],
diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py
index 7ce1a79..600ff98 100644
--- a/functionary/prompt_template/prompt_utils.py
+++ b/functionary/prompt_template/prompt_utils.py
@@ -1,12 +1,13 @@
+import base64
+import os
import random
import string
-from typing import Dict, List, Optional, Union
-from PIL import Image
from io import BytesIO
-import os
-import base64
+from typing import Dict, List, Optional, Union
+
import requests
import torch
+from PIL import Image
from transformers import LlamaTokenizer
from functionary.openai_types import ChatMessage, Function, Tool
@@ -80,14 +81,16 @@ def prepare_messages_for_inference(
prompt_template = get_prompt_template_from_tokenizer(tokenizer)
dic_messages = [mess.dict() for mess in messages]
- dic_messages.append({"role": "assistant"})
dic_messages = prompt_template.pre_process_messages_before_inference(dic_messages)
# This also checks for code_interpreter and adds the python system message
# instead of the default system message
final_prompt = prompt_template.get_prompt_from_messages(
- dic_messages, tools_or_functions=tools_or_functions
+ dic_messages,
+ tools_or_functions=tools_or_functions,
+ bos_token="",
+ add_generation_prompt=True,
)
# add prefix based on tool-choice
diff --git a/requirements.txt b/requirements.txt
index 464789e..4164362 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,4 @@ protobuf==3.20.0
tokenizers==0.19.1
vllm==0.5.4; sys_platform != "darwin"
json_source_map==1.0.5
+jinja2==3.1.4
diff --git a/tests/prompt_test_v3.llava_llama.txt b/tests/prompt_test_v3.llava_llama.txt
index 75e78ed..54b85ce 100644
--- a/tests/prompt_test_v3.llava_llama.txt
+++ b/tests/prompt_test_v3.llava_llama.txt
@@ -32,7 +32,7 @@ who is the CEO of Meetkai<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>>>all
James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|eot_id|><|start_header_id|>user<|end_header_id|>
-is the car Song more expensive than car Tang?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+is the car Song more expensive than car Tang?<|reserved_special_token_250|><|reserved_special_token_250|><|eot_id|><|start_header_id|>assistant<|end_header_id|>
>>>all
I will get the price of 2 cars and compare>>>get_car_price
@@ -46,8 +46,7 @@ I will get the price of 2 cars and compare>>>get_car_price
>>>all
No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|eot_id|><|start_header_id|>user<|end_header_id|>
-<|reserved_special_token_250|>
-what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+<|reserved_special_token_250|>what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>>>get_weather
{"location": "Hanoi"}<|eot_id|><|start_header_id|>tool<|end_header_id|>
@@ -75,4 +74,4 @@ l<|eot_id|><|start_header_id|>tool<|end_header_id|>
[0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>>>all
-The final list is: 0,1,2,3,5<|eot_id|>
+The final list is: 0,1,2,3,5<|eot_id|>
\ No newline at end of file
diff --git a/tests/test_case_vision.json b/tests/test_case_vision.json
new file mode 100644
index 0000000..d858d88
--- /dev/null
+++ b/tests/test_case_vision.json
@@ -0,0 +1,172 @@
+{
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_car_price",
+ "description": "Get the price of a particular car model",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "car_name": {
+ "type": "string",
+ "description": "The name of the car model"
+ }
+ },
+ "required": [
+ "car_name"
+ ]
+ }
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "get the weather of a location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "where to get weather"
+ }
+ },
+ "required": [
+ "location"
+ ]
+ }
+ }
+ },
+ {
+ "type": "code_interpreter"
+ }
+
+ ],
+ "messages": [
+ {
+ "role": "user",
+ "content": "who is the CEO of Meetkai"
+ },
+ {
+ "role": "assistant",
+ "content": "James Kaplan is the Co-Founder and CEO of MeetKai Inc."
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "is the car Song more expensive than car Tang?"},
+ {"type": "image_url", "image_url": "song_pic.png"},
+ {"type": "image_url", "image_url": "tang_pic.png"}
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": "I will get the price of 2 cars and compare",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "get_car_price",
+ "arguments": "{\"car_name\": \"Song\"}"
+ }
+ },
+ {
+ "function": {
+ "name": "get_car_price",
+ "arguments": "{\"car_name\": \"Tang\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "content": "{\"price\": {\"price\": \"$25000\"}}",
+ "name": "get_car_price"
+ },
+ {
+ "role": "tool",
+ "content": "{\"price\": {\"price\": \"$20000\"}}",
+ "name": "get_car_price"
+ },
+ {
+ "role": "assistant",
+ "content": "No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000."
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": "hanoi.png"},
+ {"text": "what's the weather like in Hanoi?", "type": "text"}
+ ],
+ "metainfo": {
+ "img_path": "IMG_PATH"
+ }
+ },
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "function": {
+ "name": "get_weather",
+ "arguments": "{\"location\": \"Hanoi\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "content": "{\"result\": {\"temperature\": 10}}",
+ "name": "get_weather"
+ },
+ {
+ "role": "assistant",
+ "content": "The temperature in Hanoi is: 10 degree Celcious"
+ },
+ {
+ "role": "user",
+ "content": "Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most"
+ },
+ {
+ "role": "assistant",
+ "content": "I will use code interpreter to handle this",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "python",
+ "arguments": "l=[0,1,2,3,4,5]\nl.remove(3.6)"
+ }
+ }
+ ],
+ "metadata": {
+ "masked": true
+ }
+ },
+ {
+ "role": "tool",
+ "name": "python",
+ "content": "ValueError: list.remove(x): x not in list"
+ },
+ {
+ "role": "assistant",
+ "content": "I will fix the code",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "python",
+ "arguments": "l=[0,1,2,3,4,5]\nl.remove(4)\nl"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "name": "python",
+ "content": "[0,1,2,3,5,]"
+ },
+ {
+ "role": "assistant",
+ "content": "The final list is: 0,1,2,3,5"
+ }
+ ]
+}
diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py
index 08a8bfb..f741982 100644
--- a/tests/test_prompt_creation.py
+++ b/tests/test_prompt_creation.py
@@ -37,13 +37,15 @@ class TestPromptTemplate(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestPromptTemplate, self).__init__(*args, **kwargs)
- self.template_versions = ["v2", "v2.llama3", "v3.llama3", "v3-llama3.1"]
- self.pretrained_models = [
- "meetkai/functionary-small-v2.4",
- "meetkai/functionary-small-v2.5",
- "meetkai/functionary-medium-v3.0",
- "meetkai/functionary-small-v3.1",
- ]
+ self.template_version_to_model_name = {
+ "v2": "meetkai/functionary-small-v2.4",
+ "v2.llama3": "meetkai/functionary-small-v2.5",
+ "v3.llama3": "meetkai/functionary-medium-v3.0",
+ "v3-llama3.1": "meetkai/functionary-small-v3.1",
+ }
+ self.image_template_version_to_model_name = {
+ "v3.llava_llama": "meetkai/functionary-vision-small-v0.1"
+ }
def read_example_data(self, template_version: str):
current_folder = os.path.dirname(os.path.abspath(__file__))
@@ -58,8 +60,19 @@ def read_example_data(self, template_version: str):
final_prompt = final_prompt.replace("\n\n<|from|>", "\n<|from|>")
return test_case, final_prompt
+ def read_image_example_data(self, template_version: str):
+ current_folder = os.path.dirname(os.path.abspath(__file__))
+ with open(os.path.join(current_folder, f"test_case_vision.json")) as f:
+ test_case = json.loads(f.read())
+
+ with open(
+ os.path.join(current_folder, f"prompt_test_{template_version}.txt")
+ ) as f:
+ final_prompt = f.read()
+ return test_case, final_prompt
+
def test_final_prompt_generation(self):
- for template_version in self.template_versions:
+ for template_version in self.template_version_to_model_name.keys():
print("--------------test template_version: ", template_version)
test_case, final_prompt = self.read_example_data(template_version)
tools_or_functions = (
@@ -76,10 +89,30 @@ def test_final_prompt_generation(self):
f"wrong final prompt from: get_prompt_from_messages, for version={template_version}",
)
+ for image_template_version in self.image_template_version_to_model_name.keys():
+ print("--------------test image template_version: ", image_template_version)
+ test_case, final_prompt = self.read_image_example_data(
+ image_template_version
+ )
+ tools_or_functions = (
+ test_case["tools"] if "tools" in test_case else test_case["functions"]
+ )
+ prompt_template = get_prompt_template_by_version(image_template_version)
+ created_prompt = prompt_template.get_prompt_from_messages(
+ test_case["messages"], tools_or_functions
+ )
+ print(created_prompt)
+ self.assertEqual(
+ final_prompt.strip(),
+ created_prompt.strip(),
+ f"wrong final prompt for vision from: get_prompt_from_messages, for version={image_template_version}",
+ )
+
def test_prepare_training_inputs_normal_tokenizer(self):
- for template_version, pretrained_model in zip(
- self.template_versions, self.pretrained_models
- ):
+ for (
+ template_version,
+ pretrained_model,
+ ) in self.template_version_to_model_name.items():
print(f"-------------_TEST: {template_version}, {pretrained_model}")
self.run_prepare_training_inputs(
template_version=template_version,
@@ -152,16 +185,19 @@ def run_prepare_training_inputs(
print(f"number of unmasked chunks: {len(chunks)}")
for chunk, message in zip(chunks, assistant_message):
+ sys_msg = prompt_template.get_prompt_from_messages([])
if keep_assistant_prefix:
prefix = ""
else:
- prefix = prompt_template.convert_message_to_prompt(
- {"role": "assistant"}
+ prefix = prompt_template.get_prompt_from_messages(
+ [], add_generation_prompt=True
)
+ prefix = prefix[len(sys_msg) :].lstrip()
decoded_content = prefix + tokenizer.decode(
chunk
) # note that need to add: "\nassistant" because we mask this, see line 194 in prompt_utils.py
- prompt = prompt_template.convert_message_to_prompt(message)
+ prompt = prompt_template.get_prompt_from_messages([message])
+ prompt = prompt[len(sys_msg) :].lstrip()
# decoded_content and prompt should be the same
# to avoid any mistakes of tokenizer like adding white space we will compare after removing space
self.assertEqual(
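
The tests above rely on a prefix-subtraction trick that also appears in get_raw_response_from_assistant_message: rendering an empty message list yields only the system preamble, so slicing it off a one-message render isolates that message's prompt. Shown with plain stand-in strings:

# Prefix subtraction on stand-in strings (illustrative only).
SYS = "<sys>schema</sys>"

def render(messages, add_generation_prompt=False):
    out = SYS + "".join(f"<{m['role']}>{m['content']}" for m in messages)
    return out + ("<assistant>" if add_generation_prompt else "")

sys_msg = render([])
assistant_msg = render([{"role": "assistant", "content": "hello"}])
print(assistant_msg[len(sys_msg):])
# -> <assistant>hello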