Skip to content

Commit

Permalink
Merge pull request #303 from MeetKai/refactor-v3.1-template
Browse files Browse the repository at this point in the history
Refactor v3.1 template
  • Loading branch information
jeffreymeetkai authored Jan 6, 2025
2 parents 2cba88c + 942196f commit dfd08d3
Showing 1 changed file with 58 additions and 31 deletions.
89 changes: 58 additions & 31 deletions functionary/prompt_template/llama31_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def initialize_fsm_gen_state(
"call_id": None, # call_id of the current tool
"first_chunk": True,
"first_function_chunk": True,
"text_to_func_buffer": [],
"clear_buffer": False,
"text_buffer": [],
"add_code_interpreter": add_code_interpreter,
}

Expand Down Expand Up @@ -187,33 +186,43 @@ def stream_delta_text(
prompt_utils.get_text_delta_response("", True, finish_reason)
)
gen_state["first_chunk"] = False
responses.append(
prompt_utils.get_text_delta_response(
gen_state["curr_text"], False, finish_reason
new_text = gen_state["curr_text"] + delta_text

if delta_text == "<|python_tag|>":
while gen_state["text_buffer"]:
responses.append(
prompt_utils.get_text_delta_response(
gen_state["text_buffer"][0], False, finish_reason
)
)
)
text_in_buffer = "".join(gen_state["text_to_func_buffer"] + [delta_text])
if delta_text != "<|python_tag|>" and not (
"<" in text_in_buffer
and "<function".startswith(text_in_buffer[text_in_buffer.index("<") :])
):
while len(gen_state["text_to_func_buffer"]) > 0:
delta_text_to_stream = gen_state["text_to_func_buffer"][0]
gen_state["text_buffer"] = gen_state["text_buffer"][1:]
elif "<" in delta_text:
pass
elif "<function" in new_text or "function=" in new_text:
# Stream whatever is before "<function"
new_text_from_buffer = "".join(gen_state["text_buffer"] + [delta_text])
if new_text_from_buffer.startswith("function="):
new_text_from_buffer = "<" + new_text_from_buffer
prefix = new_text_from_buffer[: new_text_from_buffer.index("<function")]
if prefix:
responses.append(
prompt_utils.get_text_delta_response(
delta_text_to_stream, False, finish_reason
prefix, False, finish_reason
)
)
gen_state["text_to_func_buffer"] = gen_state["text_to_func_buffer"][
1:
]
else:
while gen_state["text_buffer"]:
responses.append(
prompt_utils.get_text_delta_response(
gen_state["text_buffer"][0], False, finish_reason
)
)
gen_state["text_buffer"] = gen_state["text_buffer"][1:]
responses.append(
prompt_utils.get_text_delta_response(
delta_text, False, finish_reason
)
)
else:
gen_state["text_to_func_buffer"].append(delta_text)
elif gen_state["stage"] == "parameter":
if gen_state["first_function_chunk"]:
responses.append(
Expand Down Expand Up @@ -284,7 +293,7 @@ def update_fsm_gen_state(
self,
gen_state: Dict,
new_token: Optional[str],
new_token_id: Optional[str],
new_token_id: Optional[int],
options: Optional[List],
tokenizer: Any,
) -> Dict:
Expand All @@ -296,33 +305,50 @@ def update_fsm_gen_state(
gen_state["curr_text"] += new_token

if gen_state["stage"] == "pre-function":
if gen_state["curr_text"].startswith("<"):
if gen_state["curr_text"] == "<|python_tag|>":
self._update_gen_state_for_fn_call(
gen_state=gen_state, func_name="python"
)
gen_state["stage"] = "code-interpreter"
else:
gen_state["stage"] = "function"
# Always update text buffer in pre-function
if gen_state["curr_tokens"] is not None:
gen_state["text_buffer"].append(
tokenizer.decode(gen_state["curr_tokens"])[
len(gen_state["text_buffer"]) :
]
)
else:
gen_state["text_buffer"].append(new_token)

if gen_state["curr_text"] == "<|python_tag|>":
self._update_gen_state_for_fn_call(
gen_state=gen_state, func_name="python"
)
gen_state["stage"] = "code-interpreter"
elif gen_state["curr_text"].startswith("<function"):
gen_state["stage"] = "function"
elif gen_state["curr_text"] != "<":
gen_state["stage"] = "text-gen"
elif gen_state["stage"] == "text-gen":
if gen_state["curr_text"].endswith("<function"):
if "<function" in gen_state["curr_text"]:
gen_state["stage"] = "function"
gen_state["curr_text"] = "<function"
gen_state["curr_tokens"] = (
tokenizer.encode(gen_state["curr_text"], add_special_tokens=False)
if gen_state["curr_tokens"] is not None
else None
)
gen_state["text_to_func_buffer"] = []
gen_state["text_buffer"] = []
elif gen_state["curr_text"].endswith("<|python_tag|>"):
gen_state["stage"] = "code-interpreter"
gen_state = self._update_gen_state_for_fn_call(
gen_state=gen_state, func_name="python"
)
gen_state = self._reset_fsm_curr_text_and_tokens(gen_state=gen_state)
gen_state["text_to_func_buffer"] = []
gen_state["text_buffer"] = []
else:
# Add to text buffer if new token contains "<" in it
if new_token is not None and "<" in new_token:
gen_state["text_buffer"].append(new_token)
elif new_token_id is not None and "<" in tokenizer.decode(
[new_token_id]
):
gen_state["text_buffer"].append(tokenizer.decode([new_token_id]))
elif gen_state["stage"] == "function":
pattern = r"<function=[^>]+>"
match = re.search(pattern, gen_state["curr_text"])
Expand Down Expand Up @@ -350,6 +376,7 @@ def update_fsm_gen_state(
if gen_state["curr_tokens"] is not None
else None
)
gen_state["text_buffer"] = [gen_state["curr_text"]]

return gen_state

Expand Down

0 comments on commit dfd08d3

Please sign in to comment.