Skip to content

Commit

Permalink
fix json ref resolving
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreymeetkai committed Nov 5, 2024
1 parent 02ca951 commit ff40262
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
18 changes: 18 additions & 0 deletions functionary/inference_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from copy import deepcopy

import jsonref
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

Expand Down Expand Up @@ -35,3 +38,18 @@ def analyze_tools_and_tool_choice(request):
tool_func_choice = "none"

return tools_or_functions, tool_func_choice


def resolve_json_refs(tools_or_functions):
tools = deepcopy(tools_or_functions)
for i in range(len(tools)):
if "type" in tools[i] and tools[i]["type"] == "function":
tools[i]["function"]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
)
else:
tools[i]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["parameters"])
)

return tools
16 changes: 3 additions & 13 deletions functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import jinja2
import jsonref

from functionary.inference_utils import resolve_json_refs
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils

Expand Down Expand Up @@ -127,21 +127,11 @@ def get_prompt_from_messages(
str: the prompt for inference/training
"""

for i in range(len(tools_or_functions)):
if "type" in tools_or_functions[i]:
tools_or_functions[i]["function"]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(
tools_or_functions[i]["function"]["parameters"]
)
)
else:
tools_or_functions[i]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools_or_functions[i]["parameters"])
)
tools = resolve_json_refs(tools_or_functions=tools_or_functions)

prompt = self._jinja_template.render(
messages=messages,
tools=tools_or_functions,
tools=tools,
bos_token=bos_token,
add_generation_prompt=add_generation_prompt,
)
Expand Down
6 changes: 4 additions & 2 deletions functionary/vllm_monkey_patch/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Expand All @@ -57,6 +56,7 @@
from functionary.inference import (
get_lm_format_enforcer_vllm_logits_processor_from_tool_name,
)
from functionary.inference_utils import resolve_json_refs
from functionary.openai_types import Tool

logger = init_logger(__name__)
Expand Down Expand Up @@ -368,16 +368,18 @@ async def step_async(
request_id = seq_group_metadata_list[i].request_id
gen_state = self.gen_states[request_id]
tools_or_functions = self.tools_or_functions[request_id]
tools = resolve_json_refs(tools_or_functions=tools_or_functions)

# Check if the model just transitioned to "parameter" or "pre-function"
if (
gen_state["stage"] == "parameter"
and seq_group_metadata_list[i].sampling_params.logits_processors is None
):

seq_group_metadata_list[i].sampling_params.logits_processors = [
await get_lm_format_enforcer_vllm_logits_processor_from_tool_name(
tool_name=gen_state["func_name"],
tools_or_functions=tools_or_functions,
tools_or_functions=tools,
tokenizer=tokenizer,
)
]
Expand Down

0 comments on commit ff40262

Please sign in to comment.