Merge pull request #284 from MeetKai/migrate-vllm
Migrate vllm
jeffreymeetkai authored Nov 5, 2024
2 parents 46f08ef + a6ce1c8 commit 0231886
Showing 6 changed files with 655 additions and 428 deletions.
20 changes: 20 additions & 0 deletions functionary/inference_utils.py
@@ -1,3 +1,6 @@
+from copy import deepcopy
+
+import jsonref
 import torch
 from transformers import StoppingCriteria, StoppingCriteriaList
 
@@ -35,3 +38,20 @@ def analyze_tools_and_tool_choice(request):
         tool_func_choice = "none"
 
     return tools_or_functions, tool_func_choice
+
+
+def resolve_json_refs(tools_or_functions):
+    tools = deepcopy(tools_or_functions)
+    if tools:
+        for i in range(len(tools)):
+            if "type" in tools[i]:
+                if tools[i]["type"] == "function":
+                    tools[i]["function"]["parameters"] = deepcopy(
+                        jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
+                    )
+            else:
+                tools[i]["parameters"] = deepcopy(
+                    jsonref.JsonRef.replace_refs(tools[i]["parameters"])
+                )
+
+    return tools
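
For context, a minimal usage sketch of the new helper (the weather tool below is hypothetical, not part of this PR): jsonref.JsonRef.replace_refs inlines local $ref pointers, and the surrounding deepcopy materializes the lazy proxy objects it returns into plain dicts.

from functionary.inference_utils import resolve_json_refs

# Hypothetical tool whose parameter schema points into $defs via $ref.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"unit": {"$ref": "#/$defs/Unit"}},
                "$defs": {
                    "Unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
                },
            },
        },
    }
]

resolved = resolve_json_refs(tools)
# The $ref is now inlined, so downstream code sees the full enum schema:
# resolved[0]["function"]["parameters"]["properties"]["unit"]
# == {"type": "string", "enum": ["celsius", "fahrenheit"]}
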
3 changes: 2 additions & 1 deletion functionary/openai_types.py
@@ -97,11 +97,13 @@ class StreamChoice(BaseModel):
     finish_reason: Optional[str] = "stop"
     index: int = 0
 
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
 
+
 class ChatCompletionChunk(BaseModel):
     id: str
     object: str = "chat.completion.chunk"
@@ -132,7 +134,6 @@ class ChatCompletionRequest(BaseModel):
     best_of: Optional[int] = None
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
-    use_beam_search: Optional[bool] = False
 
     # @validator("tool_choice", always=True)
     # def validate_tool_choice(cls, value, values):
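
One behavioral note on dropping the field, as a hedged aside (assuming Pydantic's default config and v2-style API, neither of which this diff shows): unknown fields are ignored rather than rejected, so older clients that still send use_beam_search get no validation error; the flag is simply discarded.

from typing import Optional

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):  # trimmed sketch of the model above
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False


# Default config ignores extra fields, so the removed flag is silently dropped.
req = ChatCompletionRequest(top_k=50, use_beam_search=True)
print(req.model_dump())  # {'top_k': 50, 'ignore_eos': False}
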
6 changes: 5 additions & 1 deletion functionary/prompt_template/base_template.py
@@ -3,10 +3,12 @@
 import json
 import re
 from abc import abstractmethod
+from copy import deepcopy
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import jinja2
 
+from functionary.inference_utils import resolve_json_refs
 from functionary.openai_types import Function, Tool
 from functionary.prompt_template import prompt_utils
 
@@ -125,9 +127,11 @@ def get_prompt_from_messages(
             str: the prompt for inference/training
         """
 
+        tools = resolve_json_refs(tools_or_functions=tools_or_functions)
+
         prompt = self._jinja_template.render(
             messages=messages,
-            tools=tools_or_functions,
+            tools=tools,
             bos_token=bos_token,
             add_generation_prompt=add_generation_prompt,
         )
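
Why resolve the refs before rendering, in a minimal sketch (the one-line template here is hypothetical; the real chat templates are model-specific): Jinja serializes whatever structure it is handed, so an unresolved pointer would leak a literal "$ref" string into the prompt instead of the parameter schema the model needs.

import jinja2

from functionary.inference_utils import resolve_json_refs

# Hypothetical one-line template standing in for a model's chat template.
template = jinja2.Template("Tools: {{ tools | tojson }}")

tools = [
    {
        "type": "function",
        "function": {
            "name": "echo",
            "parameters": {
                "properties": {"msg": {"$ref": "#/$defs/Msg"}},
                "$defs": {"Msg": {"type": "string"}},
            },
        },
    }
]

print(template.render(tools=tools))
# ...{"$ref": "#/$defs/Msg"}...  -- the pointer leaks into the prompt verbatim.
print(template.render(tools=resolve_json_refs(tools)))
# ...{"type": "string"}...       -- the inlined schema the model can actually use.
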
5 changes: 2 additions & 3 deletions functionary/vllm_inference.py
@@ -198,7 +198,6 @@ async def process_chat_completion(
         best_of=request.best_of,
         top_k=request.top_k,
         ignore_eos=request.ignore_eos,
-        use_beam_search=request.use_beam_search,
         skip_special_tokens=False,
         logprobs=logprobs,
     )
@@ -207,7 +206,7 @@
 
     if enable_grammar_sampling:
         result_generator = engine.generate(
-            inputs=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
             sampling_params=sampling_params,
             request_id=request_id,
             tools_or_functions=tools_or_functions,
@@ -216,7 +215,7 @@
         )
     else:
         result_generator = engine.generate(
-            inputs=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
             sampling_params=sampling_params,
             request_id=request_id,
         )
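
Both vllm_inference.py edits track upstream vLLM API changes: SamplingParams dropped the use_beam_search flag in newer releases, and the first argument of AsyncLLMEngine.generate is now passed as prompt rather than inputs. A sketch of the updated call shape, assuming an initialized engine and a tokenized prompt (variable names as in the diff):

from vllm import SamplingParams
from vllm.inputs import TokensPrompt

# Assumes: `engine` is a running vllm AsyncLLMEngine, `prompt_token_ids` is a
# list of ints, and `request_id` is any unique string.
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
result_generator = engine.generate(
    prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
    sampling_params=sampling_params,
    request_id=request_id,
)
# generate() returns an async iterator of RequestOutput objects:
# async for output in result_generator:
#     final_output = output
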