diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 98d7b1f..dc45c08 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -1,7 +1,7 @@ import re from typing import Any, List -from functionary.prompt_template.base_template import SYSTEM_MESSAGE, PromptTemplate +from functionary.prompt_template.base_template import PromptTemplate from functionary.prompt_template.llama3_prompt_template import Llama3Template from functionary.prompt_template.llama3_prompt_template_v3 import Llama3TemplateV3 from functionary.prompt_template.llama31_prompt_template import Llama31Template diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py index 210b62c..da2548e 100644 --- a/functionary/prompt_template/base_template.py +++ b/functionary/prompt_template/base_template.py @@ -5,17 +5,21 @@ from abc import abstractmethod from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import jinja2 + from functionary.openai_types import Function, Tool from functionary.prompt_template import prompt_utils from functionary.schema import generate_schema_from_functions -SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" -PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files." - class PromptTemplate: + _jinja_env = jinja2.Environment() + _jinja_env.policies["json.dumps_kwargs"] = {"sort_keys": False} # Mapping from class --> instance to create singleton instance _instances = {} + + def __init__(self): + self._jinja_template = self._jinja_env.from_string(self.get_chat_template_jinja()) @abstractmethod def get_start_of_function_call_token(self) -> str: @@ -53,18 +57,6 @@ def get_additional_tokens(self) -> List[str]: """ raise NotImplementedError - @abstractmethod - def convert_message_to_prompt(self, message: Dict) -> str: - """Return the prompt of this message - - Args: - message (Dict): Dictionary of openAI format - - Returns: - str: prompt of this message - """ - raise NotImplementedError - @abstractmethod def get_stop_tokens_for_generation(self) -> List[str]: """Function to get list of stop tokens in generation @@ -98,47 +90,12 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di """ return messages - def inject_system_messages_based_on_tools( - self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None - ) -> List[Dict]: - """This will be used to add Default system message, code-interpreter system message if needed - - Args: - messages (List[Dict]): List of messages - tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None. 
-
- Returns:
- List[Dict]: _description_
- """
- messages_clone = messages.copy() # To avoid modifying the original list
-
- functions = []
- is_code_interpreter = False
- if tools_or_functions is not None:
- for item in tools_or_functions:
- if (
- "function" in item and item["function"] is not None
- ): # new data format: tools: [{"type": xx, "function": xxx}]
- functions.append(item["function"])
- elif "type" in item and item["type"] == "code_interpreter":
- is_code_interpreter = True
- else:
- functions.append(item) # old format
-
- messages_clone.insert(
- 0, {"role": "system", "content": generate_schema_from_functions(functions)}
- )
- if is_code_interpreter:
- messages_clone.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})
- else:
- messages_clone.insert(1, {"role": "system", "content": SYSTEM_MESSAGE})
-
- return messages_clone
-
 def get_prompt_from_messages(
 self,
 messages: List[Dict],
 tools_or_functions: Optional[List[Dict]] = None,
+ bos_token: Optional[str] = "",
+ add_generation_prompt: bool = False,
 ) -> str:
 """This function is used to get the complete prompt for list of messages
@@ -149,14 +106,15 @@ def get_prompt_from_messages(
 Returns:
 str: the prompt for inference/training
 """
- messages_clone = self.inject_system_messages_based_on_tools(
- messages, tools_or_functions
+
+ prompt = self._jinja_template.render(
+ messages=messages,
+ tools=tools_or_functions,
+ bos_token=bos_token,
+ add_generation_prompt=add_generation_prompt,
 )
- full_text = ""
- for message in messages_clone:
- full_text += self.convert_message_to_prompt(message)
- return full_text # Do not strip because llama3 uses: \n\n before content
+ return prompt

 def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:
 """return a dictionary mapping from end_token --> token_id
@@ -360,11 +318,13 @@ def get_raw_response_from_assistant_message(
 str: The mock raw response in str format
 """
 # Form raw response from messages list
- raw_response = self.convert_message_to_prompt(message)
-
- # Remove null content
- null_content = self.convert_message_to_prompt({"role": "assistant"})
- raw_response = raw_response[len(null_content) :]
+ sys_msg = self.get_prompt_from_messages(
+ messages=[], tools_or_functions=[], add_generation_prompt=True
+ )
+ assistant_msg = self.get_prompt_from_messages(
+ messages=[message], tools_or_functions=[]
+ )
+ raw_response = assistant_msg[len(sys_msg) :]

 # Remove stop tokens
 for stop_token in self.get_stop_tokens_for_generation():
@@ -375,9 +335,19 @@
 return raw_response.rstrip()

- def get_chat_template_jinja(self):
- """Return chat_template in jinja format"""
- return "{# " + f"version={self.version}" + " #}"
+ def get_chat_template_jinja(self) -> str:
+ path_prefix = "./functionary/prompt_template/jinja_templates/"
+ with open(f"{path_prefix}json_to_ts_schema.txt", "r") as f:
+ json_to_ts_schema = f.read()
+ with open(f"{path_prefix}{self.version}.txt", "r") as f:
+ template = f.read()
+
+ return (
+ template[: template.index("{%")]
+ + json_to_ts_schema
+ + "\n"
+ + template[template.index("{%") :]
+ )

 def get_generation_prefix_for_tool_choice(self, tool_choice: Any):
 if tool_choice == "auto" or tool_choice is None:
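
With the base class refactored, the per-version templates and the shared schema-generation macros move into plain-text Jinja files under `functionary/prompt_template/jinja_templates/`. Before those files, a minimal sketch (not part of the diff) of how the new API is driven end to end — it assumes the `functionary` package is importable with its bundled `jinja_templates/` directory, and the tool definition is illustrative only:

```python
# Hypothetical usage sketch: render an inference prompt through the new
# template-driven get_prompt_from_messages. The system messages (function
# schemas and, when a code_interpreter tool is present, the Jupyter notice)
# are now injected by the Jinja template itself, so callers pass only the
# raw conversation.
from functionary.prompt_template import get_prompt_template_by_version

template = get_prompt_template_by_version("v2.llama3")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # illustrative tool, mirroring the test fixtures
            "description": "get the weather of a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "where to get weather"}
                },
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "what's the weather like in Hanoi?"}]

prompt = template.get_prompt_from_messages(
    messages,
    tools_or_functions=tools,
    bos_token="",
    # Replaces the old trick of appending a dangling {"role": "assistant"}
    # message before inference (see the prompt_utils.py change below).
    add_generation_prompt=True,
)
print(prompt)  # for v2.llama3 this ends with '<|start_header_id|>assistant<|end_header_id|>\n\n'
```

Two details of the class-level setup above are load-bearing: `json.dumps_kwargs = {"sort_keys": False}` overrides Jinja2's default of sorting keys in `tojson`, presumably so the v3-llama3.1 template (which serializes tool definitions with `tojson`) emits prompts byte-identical to training data; and `get_chat_template_jinja` splices the shared `json_to_ts_schema.txt` macros in right after each template's leading `{# version=... #}` comment, which is why templates such as v2.llama3 can call `generate_schema_from_functions` without defining it in-file.

diff --git a/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt b/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt
new file mode 100644
index 0000000..aab27db
--- /dev/null
+++ b/functionary/prompt_template/jinja_templates/json_to_ts_schema.txt
@@ -0,0 +1,261 @@
+{%- macro 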
append_new_param_info(param_declaration, comment_info, examples_info, depth) -%} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- if examples_info | length > 0 -%} + {# Append each example info #} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- endif -%} + {{ "\n" + offset + param_declaration }} +{%- endmacro -%} + +{%- macro convert_data_type(param_type) -%} + {%- if param_type == "integer" or param_type == "float" -%} + {{ "number" }} + {%- else -%} + {{ param_type }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_type(param) -%} + {%- set param_type = "any" -%} + + {%- if "type" in param -%} + {%- set raw_param_type = param["type"] -%} + {%- if raw_param_type is iterable and raw_param_type is not string -%} + {%- set param_type = raw_param_type | join(" | ") -%} + {%- else -%} + {%- set param_type = raw_param_type -%} + {%- endif -%} + {{ convert_data_type(param_type) }} + {%- elif "oneOf" in param -%} + {%- set one_of_types = param["oneOf"]|selectattr("type", "defined")|list -%} + {%- set one_of_types = one_of_types|map(attribute="type")|unique|list -%} + {{ convert_data_type(one_of_types | join(" | ")) }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_format_param(param) -%} + {%- if "format" in param -%} + {{ param["format"] }} + {%- elif "oneOf" in param -%} + {%- set formats = [] -%} + {%- for item in param["oneOf"] -%} + {%- if "format" in item -%} + {%- if item["format"] == param["oneOf"][-1]["format"] -%} + {{ item["format"] }} + {%- else -%} + {{ item["format"] + " or "}} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_info(param) -%} + {%- set param_type = param.get("type", "any") -%} + {%- set format_param = get_format_param(param) -%} + + {%- if "description" in param or "default" in param or format_param != "<|NONE|>" or param["maximum"] or param["minimum"] or param["maxLength"] or param["minLength"] -%} + {{ "//" }} + {%- if "description" in param -%} + {%- set desc = param["description"] -%} + {%- if not desc.endswith(".") -%} + {%- set desc = desc + "." -%} + {%- endif -%} + {{ " " + desc }} + {%- endif -%} + + {%- if "default" in param -%} + {%- set default_value = param["default"] -%} + {%- if param_type == "string" -%} + {%- set default_value = '"' ~ default_value ~ '"' -%} + {%- endif -%} + {{ " Default=" ~ default_value ~ "." 
}} + {%- endif -%} + + {%- set format_param = get_format_param(param) -%} + {%- if format_param != "<|NONE|>" -%} + {{ " Format=" ~ format_param }} + {%- endif -%} + + {%- for field, field_name in [("maximum", "Maximum"), ("minimum", "Minimum"), ("maxLength", "Maximum length"), ("minLength", "Minimum length")] -%} + {%- if field in param -%} + {{ " " + field_name ~ "=" ~ param[field] }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>"}} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_enum_option_str(enum_options) -%} + {%- for v in enum_options -%} + {%- if v is string -%} + {{ '"' + v + '"' }} + {%- else -%} + {{ v }} + {%- endif -%} + {%- if enum_options|length > 0 and v != enum_options[-1] -%} + {{ " | " }} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro get_array_typescript(param_name, param_dic, depth) -%} + {%- set offset = '' -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- set items_info = param_dic.get('items', {}) -%} + + {%- if items_info|length == 0 -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": []" }} + {%- else -%} + {{ "\n" + offset + "[]" }} + {%- endif -%} + {%- else -%} + {%- set array_type = get_param_type(items_info) -%} + {%- if array_type == 'object' -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": {" }} + {%- else -%} + {{ "\n" + offset + "{" }} + {%- endif -%} + {{ get_parameter_typescript(items_info.get('properties', {}), items_info.get('required', []), depth + 1) -}} + {{- "\n" + offset + "}[]" }} + {%- elif array_type == 'array' -%} + {%- set item_info = get_array_typescript(None, items_info, depth + 1) -%} + {%- if not param_name -%} + {{ "\n" + item_info + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + item_info|trim + "[]" }} + {%- endif -%} + {%- else -%} + {%- if 'enum' in items_info -%} + {%- set item_type = get_enum_option_str(items_info['enum']) -%} + {%- if param_name is none -%} + {{ "(" + item_type + ")[]"}} + {%- else -%} + {{ "\n" + offset + param_name + ": (" + item_type + ")[]" }} + {%- endif -%} + {%- else -%} + {%- if param_name is none -%} + {{ "\n" + array_type + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + array_type + "[]," }} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_parameter_typescript(properties, required_params, depth=0) -%} + {%- set res = "" -%} + {%- for param_name, param in properties.items() -%} + {%- if param is mapping -%} + {%- set comment_info = get_param_info(param) -%} + {# Param Examples #} + {%- set examples_info = [] -%} + {%- if "examples" in param -%} + {%- set examples_info = ["Example " + param_name + ":"] -%} + {%- set examples_info = examples_info + param["examples"] -%} + {%- endif -%} + + {# Param Name declaration #} + {%- set param_declaration = param_name -%} + {%- if required_params is iterable and param_name not in required_params -%} + {%- set param_declaration = param_declaration + "?" 
-%} + {%- endif -%} + + {%- set param_type = get_param_type(param) -%} + + {# Handle indentation based on depth #} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + + {%- if param_type == "object" -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": {" -%} + {{ "\n" + offset + param_declaration -}} + {{- get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1) -}} + {{- "\n" + offset + "}," }} + {%- elif param_type == "array" -%} + {%- set item_info = param.get("items", {}) -%} + {%- if "type" not in item_info -%} + {%- set param_declaration = param_declaration + ": []," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- else -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set array_declaration = get_array_typescript(param_declaration, param, depth) -%} + {%- if not array_declaration.endswith(",") -%} + {%- set array_declaration = array_declaration + "," -%} + {%- endif -%} + {{ array_declaration}} + {%- endif -%} + {%- else -%} + {%- if "enum" in param -%} + {%- set param_type = get_enum_option_str(param["enum"]) -%} + {%- endif -%} + {%- if "nullable" in param and param["nullable"] -%} + {%- set param_type = param_type + " | null" -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": " + param_type + "," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro generate_schema_from_functions(functions, namespace='functions') -%} + {{ "// Supported function definitions that should be called when necessary.\n" -}} + {{- "namespace " + namespace + " {\n\n" -}} + + {%- for function in functions -%} + {%- if function.get("function") -%} + {%- set function = function.get("function") -%} + {%- endif -%} + + {%- set function_name = function.get("name") -%} + {%- if function_name -%} + {%- set description = function.get('description', '') -%} + {%- set parameters = function.get('parameters', {}) -%} + {{- "// " + description + "\n" -}} + {{- "type " + function_name -}} + {%- if parameters and parameters.get("properties") -%} + {{- " = (_: {" -}} + {%- set required_params = parameters.get("required", []) -%} + {{ get_parameter_typescript(parameters.get("properties"), required_params, 0) -}} + {{- "\n}) => any;\n\n" }} + {%- else -%} + {{ " = () => any;\n\n" }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {{ "} // namespace " + namespace }} +{%- endmacro -%} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/v2.llama3.txt b/functionary/prompt_template/jinja_templates/v2.llama3.txt new file mode 100644 index 0000000..4796a9a --- /dev/null +++ b/functionary/prompt_template/jinja_templates/v2.llama3.txt @@ -0,0 +1,28 @@ +{# version=v2.llama3 #}{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\n' + generate_schema_from_functions(tools) + 
'<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- else -%} + {{ "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary<|eot_id|>" }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + 'name=' + message['name'] + '\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '<|reserved_special_token_249|>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/v2.txt b/functionary/prompt_template/jinja_templates/v2.txt new file mode 100644 index 0000000..ff7dc86 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/v2.txt @@ -0,0 +1,27 @@ +{# version=v2 #}{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|from|>system\n<|recipient|>all\n<|content|>' + generate_schema_from_functions(tools) -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '\n<|from|>system\n<|recipient|>all\n<|content|>When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.' }} +{%- else -%} + {{ "\n<|from|>system\n<|recipient|>all\n<|content|>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
The assistant calls functions with appropriate input when necessary" }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '\n<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] }} + {%- elif message['role'] == 'tool' -%} + {{ '\n<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] }} + {%- else -%} + {%- if message['content'] -%} + {{ "\n<|from|>" + message['role'] + "\n<|recipient|>all\n<|content|>" + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '\n<|from|>' + message['role'] + '\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ "<|stop|>" }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '\n<|from|>assistant\n<|recipient|>' }}{% endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/v3-llama3.1.txt b/functionary/prompt_template/jinja_templates/v3-llama3.1.txt new file mode 100644 index 0000000..29d64a2 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/v3-llama3.1.txt @@ -0,0 +1,58 @@ +{# version=v3-llama3.1 #}{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} + +{#- System message + builtin tools #} +{{- bos_token + "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if has_code_interpreter %} + {{- "Environment: ipython\n\n" }} +{%- else -%} + {{ "\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n\n" }} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '"|safe + t["function"]["name"] + "' to '"|safe + t["function"]["description"] + "'\n"|safe + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '"|safe + t["name"] + "' to '"|safe + t["description"] + "'\n"|safe + t | tojson() }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with \n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}} +{%- endif %} +{{- "<|eot_id|>" -}} + +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>ipython<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if 
message['content'] -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {{ '<|python_tag|>' + tool_call['function']['arguments'] }} + {%- else -%} + {{ '' + tool_call['function']['arguments'] + '' }} + {%- endif -%} + {%- endfor -%} + {{ '<|eom_id|>' }} + {%- else -%} + {{ '<|eot_id|>' }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif -%} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/v3.llama3.txt b/functionary/prompt_template/jinja_templates/v3.llama3.txt new file mode 100644 index 0000000..5a1d8a5 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/v3.llama3.txt @@ -0,0 +1,26 @@ +{# version=v3.llama3 #}{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ '>>>all\n' + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/v3.llava_llama.txt b/functionary/prompt_template/jinja_templates/v3.llava_llama.txt new file mode 100644 index 0000000..c3abf5a --- /dev/null +++ b/functionary/prompt_template/jinja_templates/v3.llava_llama.txt @@ -0,0 +1,40 @@ +{# version=v3.llava_llama #}{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + 
generate_schema_from_functions(tools) + '<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'user' -%} + {%- if message['content'] is string -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {{ content['text'] }} + {%- else -%} + {{ '<|reserved_special_token_250|>' }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ '>>>all\n' + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %} \ No newline at end of file diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py index cff00a2..c237ce1 100644 --- a/functionary/prompt_template/llama31_prompt_template.py +++ b/functionary/prompt_template/llama31_prompt_template.py @@ -8,69 +8,6 @@ from functionary.prompt_template.base_template import PromptTemplate -def get_system_prompt_for_custom_tools(custom_tools: List) -> str: - custom_tool_params = "" - for t in custom_tools: - custom_tool_params += get_instruction_string(t) + "\n" - custom_tool_params += get_parameters_string(t) + "\n\n" - - content = f""" -You have access to the following functions: - -{custom_tool_params} -Think very carefully before calling functions. -If a you choose to call a function ONLY reply in the following format: -<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}} -where - -start_tag => ` a JSON dict with the function argument name as key and function argument value as value. 
-end_tag => `` - -Here is an example, -{{"example_name": "example_value"}} - -Reminder: -- If looking for real time information use relevant functions before falling back to brave_search -- Function calls MUST follow the specified format, start with -- Required parameters MUST be specified -- Only call one function at a time -- Put the entire function call reply on one line - -""" - return content - - -def get_instruction_string(custom_tool_definition) -> str: - name, description = ( - custom_tool_definition["name"], - custom_tool_definition["description"], - ) - return f"Use the function '{name}' to '{description}'" - - -def get_parameters_string(custom_tool_definition) -> str: - return json.dumps(custom_tool_definition) - - -def get_system_message_for_tools(tools: List[Dict], use_code_interpreter) -> List[Dict]: - content = "" - if use_code_interpreter: - content += "Environment: ipython\n" - - current_date = datetime.datetime.now() - formatted_date = current_date.strftime("%d %B %Y") - date_str = f""" -Cutting Knowledge Date: December 2023\n\n""" - content += date_str - - if tools: - custom_message = get_system_prompt_for_custom_tools(tools) - content += custom_message - - return {"role": "system", "content": content} - - def parse_function_call_from_text(function_call_text: str) -> Optional[Dict]: index = function_call_text.find(">") if index >= 0: @@ -109,87 +46,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di def get_stop_tokens_for_generation(self) -> List[str]: return [self.eos_token, "<|end_of_text|>", self.eof_message] - def inject_system_messages_based_on_tools( - self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None - ) -> List[Dict]: - """This will be used to add Default system message, code-interpreter system message if needed - - Args: - messages (List[Dict]): List of messages - tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None. 
- - Returns: - List[Dict]: _description_ - """ - messages_clone = messages.copy() # To avoid modifying the original list - - functions = [] - is_code_interpreter = False - if tools_or_functions is not None: - for item in tools_or_functions: - if ( - "function" in item and item["function"] is not None - ): # new data format: tools: [{"type": xx, "function": xxx}] - functions.append(item["function"]) - elif "type" in item and item["type"] == "code_interpreter": - is_code_interpreter = True - else: - functions.append(item) # old format - - tools_system_message = get_system_message_for_tools( - functions, is_code_interpreter - ) - messages_clone.insert(0, tools_system_message) - - return messages_clone - - def convert_message_to_prompt(self, message: Dict) -> str: - role = message["role"] - if role == "tool": - role = "ipython" - content = message.get("content", None) - - prompt_template = ( - f"{self.start_header}{role}{self.end_header}\n\n" + "{text}{eot_content}" - ) - eot_content = self.eos_token - - if role in ["user", "system", "ipython"]: - return prompt_template.format(text=content, eot_content=eot_content) - - assert role == "assistant", f"role must be assistant, but: {role}" - - # set content=none if content="" - if type(content) is str and len(content) == 0: - content = None - - tool_calls = message.get("tool_calls", []) - if tool_calls is None: - tool_calls = [] - - if content is None and len(tool_calls) == 0: # inference time - return f"{self.start_header}{role}{self.end_header}\n\n" - - total_content = content if content else "" - - # list of text representing function calls: {function_name}\n{arguments} - tool_call_prompts = [] - for tool_call in tool_calls: - arguments = tool_call["function"]["arguments"] - assert isinstance(arguments, str) - tool_name = tool_call["function"]["name"] - if tool_name == "python": - tool_prompt = f"<|python_tag|>{arguments}" - else: - tool_prompt = f"{arguments}" - tool_call_prompts.append(tool_prompt) - - # join all function calls - if tool_call_prompts: - total_content += "".join(tool_call_prompts) - eot_content = self.eof_message - return prompt_template.format(text=total_content, eot_content=eot_content) - def parse_assistant_response( self, llm_output: str, tool_choice: Any = None ) -> Dict: diff --git a/functionary/prompt_template/llama3_prompt_template.py b/functionary/prompt_template/llama3_prompt_template.py index 5042acb..ef882c8 100644 --- a/functionary/prompt_template/llama3_prompt_template.py +++ b/functionary/prompt_template/llama3_prompt_template.py @@ -8,22 +8,6 @@ from functionary.prompt_template.base_template import PromptTemplate -def convert_to_llama3_messages(messages: List[Dict]) -> List[Dict]: - result = [] - index = 0 - while index < len(messages): - if messages[index]["role"] in ["user", "system"]: - result.append(messages[index]) - index += 1 - else: - if messages[index]["role"] == "assistant": - tool_calls = messages[index].get("tool_calls", []) - if len(tool_calls) == 0: - result.append(messages[index]) - else: - messages - - class Llama3Template(PromptTemplate): function_separator = "<|reserved_special_token_249|>" version = "v2.llama3" @@ -172,58 +156,6 @@ def parse_assistant_response( tool_calls = None if len(tool_calls) == 0 else tool_calls return {"role": "assistant", "content": text_content, "tool_calls": tool_calls} - def convert_message_to_prompt(self, message: Dict) -> str: - role = message["role"] - content = message.get("content", None) - if role == "tool": - tool_name = message["name"] - content = 
f"name={tool_name}\n{content}" - - prompt_template = ( - "<|start_header_id|>%s<|end_header_id|>\n\n{text}<|eot_id|>" % role - ) - - if role in ["user", "system", "tool"]: - return prompt_template.format(text=content) - - assert role == "assistant" - # set content=none if content="" - if type(content) is str and len(content) == 0: - content = None - - tool_calls = message.get("tool_calls", []) - if tool_calls is None: - tool_calls = [] - - if content is None and len(tool_calls) == 0: - return f"<|start_header_id|>{role}<|end_header_id|>\n\n" - - if content is not None and len(tool_calls) == 0: # text-only - return prompt_template.format(text=content) - - tool_call_prompts = [] - for tool_call in tool_calls: - arguments = tool_call["function"]["arguments"] - tool_name = tool_call["function"]["name"] - tool_prompt = f"{tool_name}\n{arguments}" - tool_call_prompts.append(tool_prompt) - - if ( - content is None and len(tool_calls) > 0 - ): # function call only (fc1fc2) - tool_call_content = self.function_separator + self.function_separator.join( - tool_call_prompts - ) - return prompt_template.format(text=tool_call_content) - - # Here is the case contains both text-response and tool_calls (contentfc1fc2) - total_content = ( - content - + self.function_separator - + self.function_separator.join(tool_call_prompts) - ) - return prompt_template.format(text=total_content) - def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: """Order the tool results by the order of tool call ids @@ -235,18 +167,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di """ return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) - def update_state_for_function(self, current_state): - """update the state when a function is going to be called - - Args: - current_state (_type_): _description_ - """ - current_state["response_type"] = "function" - current_state["skip_until_reach"] = "\n" - current_state["current_text"] = "" - current_state["func_index"] += 1 - current_state["call_id"] = prompt_utils.get_random_tool_call_id() - def initialize_fsm_gen_state( self, tool_choice: Union[str, Tool], @@ -441,32 +361,6 @@ def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List): return options - def get_chat_template_jinja(self) -> str: - chat_template = """{% for message in messages %} - {% if message['role'] == 'user' or message['role'] == 'system' %} - {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + 'name=' + message['name'] + '\n' + message['content'] + '<|eot_id|>' }}<br>
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}<br>
- {% if message['content'] is not none %}
- {{ message['content'] }}<br>
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '<|reserved_special_token_249|>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}<br>
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}<br>
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace("    ", "")
- chat_template = chat_template.replace("<br>
\n", "") - chat_template = chat_template.strip() - return chat_template - def get_force_text_generation_prefix(self): return "" diff --git a/functionary/prompt_template/llama3_prompt_template_v3.py b/functionary/prompt_template/llama3_prompt_template_v3.py index 91af993..a5793a7 100644 --- a/functionary/prompt_template/llama3_prompt_template_v3.py +++ b/functionary/prompt_template/llama3_prompt_template_v3.py @@ -3,18 +3,7 @@ from functionary.openai_types import Function, Tool from functionary.prompt_template import prompt_utils -from functionary.prompt_template.base_template import PYTHON_RUN_SYS_MSG, PromptTemplate -from functionary.schema import generate_schema_from_functions - -SYSTEM_CONTENT = """You are capable of executing available function(s) if required. -Only execute function(s) when absolutely necessary. -Ask for the required input to:recipient==all -Use JSON for function arguments. -Respond in this format: ->>>${recipient} -${content} -Available functions: -""" +from functionary.prompt_template.base_template import PromptTemplate class Llama3TemplateV3(PromptTemplate): @@ -147,56 +136,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di """ return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) - def convert_message_to_prompt(self, message: Dict) -> str: - role = message["role"] - content = message.get("content", None) - - # comment this as currently the Llama-70b was trained using this - # if role == "tool": - # tool_name = message["name"] - # content = f"name={tool_name}\n{content}" - - prompt_template = ( - f"{self.start_header}{role}{self.end_header}\n\n" - + "{text}" - + self.eos_token - ) - - if role in ["user", "system", "tool"]: - return prompt_template.format(text=content) - - assert role == "assistant", f"role must be assistant, but: {role}" - - # set content=none if content="" - if type(content) is str and len(content) == 0: - content = None - - tool_calls = message.get("tool_calls", []) - if tool_calls is None: - tool_calls = [] - - if content is None and len(tool_calls) == 0: # inference time - return f"{self.start_header}{role}{self.end_header}\n\n{self.function_separator}" - - if content is not None: # text-only - tool_calls = [ - {"function": {"name": "all", "arguments": content}} - ] + tool_calls - - # list of text representing function calls: {function_name}\n{arguments} - tool_call_prompts = [] - for tool_call in tool_calls: - arguments = tool_call["function"]["arguments"] - tool_name = tool_call["function"]["name"] - tool_prompt = f"{tool_name}\n{arguments}" - tool_call_prompts.append(tool_prompt) - - # join all function calls - total_content = self.function_separator + self.function_separator.join( - tool_call_prompts - ) - return prompt_template.format(text=total_content) - def parse_assistant_response( self, llm_output: str, tool_choice: Any = None ) -> Dict: @@ -235,45 +174,6 @@ def parse_assistant_response( return {"role": "assistant", "content": text_content, "tool_calls": tool_calls} - def inject_system_messages_based_on_tools( - self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None - ) -> List[Dict]: - """This will be used to add Default system message, code-interpreter system message if needed - - Args: - messages (List[Dict]): List of messages - tools_or_functions (Optional[List[Dict]], optional): List of tools, functions. Defaults to None. 
- - Returns: - List[Dict]: _description_ - """ - messages_clone = messages.copy() # To avoid modifying the original list - - functions = [] - is_code_interpreter = False - if tools_or_functions is not None: - for item in tools_or_functions: - if ( - "function" in item and item["function"] is not None - ): # new data format: tools: [{"type": xx, "function": xxx}] - functions.append(item["function"]) - elif "type" in item and item["type"] == "code_interpreter": - is_code_interpreter = True - else: - functions.append(item) # old format - - messages_clone.insert( - 0, - { - "role": "system", - "content": SYSTEM_CONTENT + generate_schema_from_functions(functions), - }, - ) - if is_code_interpreter: - messages_clone.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG}) - - return messages_clone - def get_force_text_generation_prefix(self): return f"all\n" @@ -478,29 +378,3 @@ def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List): options = [self.fn_param_sep_token] return options - - def get_chat_template_jinja(self) -> str: - chat_template = """{% for message in messages %} - {% if message['role'] == 'user' or message['role'] == 'system' %} - {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}<br>
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}<br>
- {% if message['content'] is not none %}
- {{ '>>>all\n' + message['content'] }}<br>
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}<br>
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}<br>
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace("    ", "")
- chat_template = chat_template.replace("<br>
\n", "") - chat_template = chat_template.strip() - return chat_template diff --git a/functionary/prompt_template/llava_prompt_template.py b/functionary/prompt_template/llava_prompt_template.py index 1a3b64a..bc9f7e2 100644 --- a/functionary/prompt_template/llava_prompt_template.py +++ b/functionary/prompt_template/llava_prompt_template.py @@ -9,52 +9,3 @@ class LlavaLlama(Llama3TemplateV3): version = "v3.llava_llama" # This token will be replaced with image_token_id (-200) after we tokenize the text image_token = "<|reserved_special_token_250|>" - - def convert_message_to_prompt(self, message: Dict) -> str: - role = message["role"] - content = message.get("content", None) - - # handle the case when user uploads images (content is a list) - if role == "user" and type(content) is list: - text_content = prompt_utils.stringify_content_with_images( - message["content"], self.image_token - ) - return f"{self.start_header}{role}{self.end_header}\n\n{text_content}{self.eos_token}" - return super().convert_message_to_prompt(message) - - def get_chat_template_jinja(self) -> str: - chat_template = """{% for message in messages %} - {% if message['role'] == 'user'%} - {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
- {% if message['content'] is iterable and (message['content'] is not string and message['content'] is not mapping) %}
- {% for content_item in message['content'] %}
- {% if content_item['type'] == 'image_url' %}
- {{ '<|reserved_special_token_250|>' }}<br>
- {% else %}
- {{ content_item['text'] }}<br>
- {% endif %}
- {% endfor %}
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}<br>
- {% endif %}
- {% elif message['role'] == 'tool' or message['role'] == 'system' %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}<br>
- {% else %}
- {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}<br>
- {% if message['content'] is not none %}
- {{ '>>>all\n' + message['content'] }}<br>
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}<br>
- {% endfor %}
- {% endif %}
- {{ '<|eot_id|>' }}<br>
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|start_header_id|>{role}<|end_header_id|>\n\n' }}{% endif %}
- """
- chat_template = chat_template.replace("    ", "")
- chat_template = chat_template.replace("<br>
\n", "") - chat_template = chat_template.strip() - return chat_template diff --git a/functionary/prompt_template/prompt_template_v2.py b/functionary/prompt_template/prompt_template_v2.py index 66731cf..64c7b93 100644 --- a/functionary/prompt_template/prompt_template_v2.py +++ b/functionary/prompt_template/prompt_template_v2.py @@ -121,50 +121,6 @@ def get_additional_tokens(self) -> List[str]: self.stop_token, ] - def convert_message_to_prompt(self, message: Dict) -> str: - role = message["role"] - content = message.get("content", None) - - if role in [ - "system", - "user", - ]: # <|from|>system\n<|recipient|>all\n<|content|>xxx - return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n" - - if role == "tool": # <|from|>tool_name\n<|recipient|>all\n<|content|>xxx - tool_name = message["name"] - return f"{self.from_token}{tool_name}\n{self.recipient_token}all\n{self.content_token}{content}\n" - - assert role == "assistant" - - # set content=none if content="" - if type(content) is str and len(content) == 0: - content = None - - tool_calls = message.get("tool_calls", []) - if tool_calls is None: - tool_calls = [] - if ( - len(tool_calls) == 0 and content is None - ): # for inference: <|from|> assistant\n<|recipient|> - return f"{self.from_token}{role}\n{self.recipient_token}" - - if len(tool_calls) == 0: # <|from|>assistant\n<|recipient|>all\n<|content|>xxx - return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}{self.stop_token}\n" - - result = "" - if content is not None: # both text-response and function_call - result += f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n" - - for tool in tool_calls: - func_name = tool["function"]["name"] - arguments = tool["function"]["arguments"] - # <|from|>assistant\n<|recipient|>func_name\n<|content|>xxxx - result += f"{self.from_token}{role}\n{self.recipient_token}{func_name}\n{self.content_token}{arguments}\n" - - result = result.strip() + f"{self.stop_token}\n" - return result - def get_stop_tokens_for_generation(self) -> List[str]: return [self.stop_token] @@ -252,38 +208,6 @@ def get_recipient(self, current_text: str) -> str: end_index = current_text.find(f"\n{self.content_token}") return current_text[start_index:end_index].strip() - def get_chat_template_jinja(self) -> str: - chat_template = """{% for message in messages %} - {% if message['role'] == 'user' or message['role'] == 'system' %} - {{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}<br>
- {% else %}
- {% set contain_content='no'%}
- {% if message['content'] is not none %}
- {{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}<br>
- {% set contain_content='yes'%}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}
- {% if loop.index == 1 and contain_content == "no" %}
- {{ prompt }}<br>
- {% else %}
- {{ '\n' + prompt}}<br>
- {% endif %}
- {% endfor %}
- {% endif %}
- {{ '<|stop|>\n' }}<br>
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}
- """
- chat_template = chat_template.replace("    ", "")
- chat_template = chat_template.replace("<br>
\n", "") - chat_template = chat_template.strip() - return chat_template - def initialize_fsm_gen_state( self, tool_choice: Union[str, Tool], diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py index 7ce1a79..600ff98 100644 --- a/functionary/prompt_template/prompt_utils.py +++ b/functionary/prompt_template/prompt_utils.py @@ -1,12 +1,13 @@ +import base64 +import os import random import string -from typing import Dict, List, Optional, Union -from PIL import Image from io import BytesIO -import os -import base64 +from typing import Dict, List, Optional, Union + import requests import torch +from PIL import Image from transformers import LlamaTokenizer from functionary.openai_types import ChatMessage, Function, Tool @@ -80,14 +81,16 @@ def prepare_messages_for_inference( prompt_template = get_prompt_template_from_tokenizer(tokenizer) dic_messages = [mess.dict() for mess in messages] - dic_messages.append({"role": "assistant"}) dic_messages = prompt_template.pre_process_messages_before_inference(dic_messages) # This also checks for code_interpreter and adds python default system message instead # default system message final_prompt = prompt_template.get_prompt_from_messages( - dic_messages, tools_or_functions=tools_or_functions + dic_messages, + tools_or_functions=tools_or_functions, + bos_token="", + add_generation_prompt=True, ) # add prefix based on tool-choice diff --git a/requirements.txt b/requirements.txt index 464789e..4164362 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ protobuf==3.20.0 tokenizers==0.19.1 vllm==0.5.4; sys_platform != "darwin" json_source_map==1.0.5 +jinja2==3.1.4 diff --git a/tests/prompt_test_v3.llava_llama.txt b/tests/prompt_test_v3.llava_llama.txt index 75e78ed..54b85ce 100644 --- a/tests/prompt_test_v3.llava_llama.txt +++ b/tests/prompt_test_v3.llava_llama.txt @@ -32,7 +32,7 @@ who is the CEO of Meetkai<|eot_id|><|start_header_id|>assistant<|end_header_id|> >>>all James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|eot_id|><|start_header_id|>user<|end_header_id|> -is the car Song more expensive than car Tang?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +is the car Song more expensive than car Tang?<|reserved_special_token_250|><|reserved_special_token_250|><|eot_id|><|start_header_id|>assistant<|end_header_id|> >>>all I will get the price of 2 cars and compare>>>get_car_price @@ -46,8 +46,7 @@ I will get the price of 2 cars and compare>>>get_car_price >>>all No, the car Tang is less expensive than the car Song. 
The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|eot_id|><|start_header_id|>user<|end_header_id|> -<|reserved_special_token_250|> -what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +<|reserved_special_token_250|>what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_header_id|> >>>get_weather {"location": "Hanoi"}<|eot_id|><|start_header_id|>tool<|end_header_id|> @@ -75,4 +74,4 @@ l<|eot_id|><|start_header_id|>tool<|end_header_id|> [0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|> >>>all -The final list is: 0,1,2,3,5<|eot_id|> +The final list is: 0,1,2,3,5<|eot_id|> \ No newline at end of file diff --git a/tests/test_case_vision.json b/tests/test_case_vision.json new file mode 100644 index 0000000..d858d88 --- /dev/null +++ b/tests/test_case_vision.json @@ -0,0 +1,172 @@ +{ + "tools": [ + { + "type": "function", + "function": { + "name": "get_car_price", + "description": "Get the price of a particular car model", + "parameters": { + "type": "object", + "properties": { + "car_name": { + "type": "string", + "description": "The name of the car model" + } + }, + "required": [ + "car_name" + ] + } + } + }, + { + "type": "function", + "function": { + "name": "get_weather", + "description": "get the weather of a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "where to get weather" + } + }, + "required": [ + "location" + ] + } + } + }, + { + "type": "code_interpreter" + } + + ], + "messages": [ + { + "role": "user", + "content": "who is the CEO of Meetkai" + }, + { + "role": "assistant", + "content": "James Kaplan is the Co-Founder and CEO of MeetKai Inc." + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "is the car Song more expensive than car Tang?"}, + {"type": "image_url", "image_url": "song_pic.png"}, + {"type": "image_url", "image_url": "tang_pic.png"} + ] + }, + { + "role": "assistant", + "content": "I will get the price of 2 cars and compare", + "tool_calls": [ + { + "function": { + "name": "get_car_price", + "arguments": "{\"car_name\": \"Song\"}" + } + }, + { + "function": { + "name": "get_car_price", + "arguments": "{\"car_name\": \"Tang\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"price\": {\"price\": \"$25000\"}}", + "name": "get_car_price" + }, + { + "role": "tool", + "content": "{\"price\": {\"price\": \"$20000\"}}", + "name": "get_car_price" + }, + { + "role": "assistant", + "content": "No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000." 
+ }, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": "hanoi.png"}, + {"text": "what's the weather like in Hanoi?", "type": "text"} + ], + "metainfo": { + "img_path": "IMG_PATH" + } + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Hanoi\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"result\": {\"temperature\": 10}}", + "name": "get_weather" + }, + { + "role": "assistant", + "content": "The temperature in Hanoi is: 10 degree Celcious" + }, + { + "role": "user", + "content": "Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most" + }, + { + "role": "assistant", + "content": "I will use code interpreter to handle this", + "tool_calls": [ + { + "function": { + "name": "python", + "arguments": "l=[0,1,2,3,4,5]\nl.remove(3.6)" + } + } + ], + "metadata": { + "masked": true + } + }, + { + "role": "tool", + "name": "python", + "content": "ValueError: list.remove(x): x not in list" + }, + { + "role": "assistant", + "content": "I will fix the code", + "tool_calls": [ + { + "function": { + "name": "python", + "arguments": "l=[0,1,2,3,4,5]\nl.remove(4)\nl" + } + } + ] + }, + { + "role": "tool", + "name": "python", + "content": "[0,1,2,3,5,]" + }, + { + "role": "assistant", + "content": "The final list is: 0,1,2,3,5" + } + ] +} diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 08a8bfb..f741982 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -37,13 +37,15 @@ class TestPromptTemplate(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestPromptTemplate, self).__init__(*args, **kwargs) - self.template_versions = ["v2", "v2.llama3", "v3.llama3", "v3-llama3.1"] - self.pretrained_models = [ - "meetkai/functionary-small-v2.4", - "meetkai/functionary-small-v2.5", - "meetkai/functionary-medium-v3.0", - "meetkai/functionary-small-v3.1", - ] + self.template_version_to_model_name = { + "v2": "meetkai/functionary-small-v2.4", + "v2.llama3": "meetkai/functionary-small-v2.5", + "v3.llama3": "meetkai/functionary-medium-v3.0", + "v3-llama3.1": "meetkai/functionary-small-v3.1", + } + self.image_template_version_to_model_name = { + "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" + } def read_example_data(self, template_version: str): current_folder = os.path.dirname(os.path.abspath(__file__)) @@ -58,8 +60,19 @@ def read_example_data(self, template_version: str): final_prompt = final_prompt.replace("\n\n<|from|>", "\n<|from|>") return test_case, final_prompt + def read_image_example_data(self, template_version: str): + current_folder = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(current_folder, f"test_case_vision.json")) as f: + test_case = json.loads(f.read()) + + with open( + os.path.join(current_folder, f"prompt_test_{template_version}.txt") + ) as f: + final_prompt = f.read() + return test_case, final_prompt + def test_final_prompt_generation(self): - for template_version in self.template_versions: + for template_version in self.template_version_to_model_name.keys(): print("--------------test template_version: ", template_version) test_case, final_prompt = self.read_example_data(template_version) tools_or_functions = ( @@ -76,10 +89,30 @@ def test_final_prompt_generation(self): f"wrong final prompt from: get_prompt_from_messages, for version={template_version}", ) + for image_template_version in 
self.image_template_version_to_model_name.keys(): + print("--------------test image template_version: ", image_template_version) + test_case, final_prompt = self.read_image_example_data( + image_template_version + ) + tools_or_functions = ( + test_case["tools"] if "tools" in test_case else test_case["functions"] + ) + prompt_template = get_prompt_template_by_version(image_template_version) + created_prompt = prompt_template.get_prompt_from_messages( + test_case["messages"], tools_or_functions + ) + print(created_prompt) + self.assertEqual( + final_prompt.strip(), + created_prompt.strip(), + f"wrong final prompt for vision from: get_prompt_from_messages, for version={image_template_version}", + ) + def test_prepare_training_inputs_normal_tokenizer(self): - for template_version, pretrained_model in zip( - self.template_versions, self.pretrained_models - ): + for ( + template_version, + pretrained_model, + ) in self.template_version_to_model_name.items(): print(f"-------------_TEST: {template_version}, {pretrained_model}") self.run_prepare_training_inputs( template_version=template_version, @@ -152,16 +185,19 @@ def run_prepare_training_inputs( print(f"number of unmasked chunks: {len(chunks)}") for chunk, message in zip(chunks, assistant_message): + sys_msg = prompt_template.get_prompt_from_messages([]) if keep_assistant_prefix: prefix = "" else: - prefix = prompt_template.convert_message_to_prompt( - {"role": "assistant"} + prefix = prompt_template.get_prompt_from_messages( + [], add_generation_prompt=True ) + prefix = prefix[len(sys_msg) :].lstrip() decoded_content = prefix + tokenizer.decode( chunk ) # note that need to add: "\nassistant" because we mask this, see line 194 in prompt_utils.py - prompt = prompt_template.convert_message_to_prompt(message) + prompt = prompt_template.get_prompt_from_messages([message]) + prompt = prompt[len(sys_msg) :].lstrip() # decoded_content and prompt should be the same # to avoid any mistakes of tokenizer like adding white space we will compare after removing space self.assertEqual(
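
The assertions above rely on a prefix-stripping idiom that replaces the deleted `convert_message_to_prompt`: render the system-only prompt once, then slice it off the front of a longer render. A minimal sketch of that idiom (the helper name is hypothetical; `get_prompt_from_messages` is the real API, and passing `tools_or_functions=[]` mirrors `get_raw_response_from_assistant_message` in base_template.py):

```python
# Hypothetical helper distilling the idiom used by the updated tests: recover
# the prompt fragment contributed by a single message now that
# convert_message_to_prompt is gone. This works because rendering is
# deterministic, so the system block is a fixed prefix of any longer render.
def prompt_for_single_message(prompt_template, message: dict) -> str:
    sys_msg = prompt_template.get_prompt_from_messages([], tools_or_functions=[])
    full = prompt_template.get_prompt_from_messages([message], tools_or_functions=[])
    return full[len(sys_msg):].lstrip()  # .lstrip() matches the test comparison
```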