From 0c25249c52b1566d550618aa4828354907e7bf8e Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 5 Nov 2024 13:59:00 +0000 Subject: [PATCH] refactor tool_call to function_call; vllm health endpoint --- functionary/inference_utils.py | 35 ++++++++++++++++++++++++++++- functionary/sglang_inference.py | 39 ++++++--------------------------- functionary/vllm_inference.py | 34 +++++++++------------------- server_vllm.py | 7 ++++-- 4 files changed, 56 insertions(+), 59 deletions(-) diff --git a/functionary/inference_utils.py b/functionary/inference_utils.py index 8855223..89b5c35 100644 --- a/functionary/inference_utils.py +++ b/functionary/inference_utils.py @@ -1,6 +1,6 @@ from copy import deepcopy from http import HTTPStatus -from typing import Optional +from typing import Dict, List, Optional import jsonref import torch @@ -8,6 +8,7 @@ from pydantic import BaseModel from transformers import StoppingCriteria, StoppingCriteriaList +from functionary.openai_types import Function from functionary.prompt_template.prompt_utils import enforce_tool_choice @@ -128,3 +129,35 @@ def resolve_json_refs(tools_or_functions): ) return tools + + +def convert_tool_calls_to_function_call( + functions: Optional[List[Function]], chat_message: Dict +) -> Dict: + if "delta" not in chat_message: # Non-streaming + if ( + functions + and len(functions) > 0 + and "tool_calls" in chat_message + and chat_message["tool_calls"] is not None + and len(chat_message["tool_calls"]) > 0 + ): + chat_message["function_call"] = { + "name": chat_message["tool_calls"][0]["function"]["name"], + "arguments": chat_message["tool_calls"][0]["function"]["arguments"], + } + chat_message["tool_calls"] = None + else: # Streaming + if ( + functions + and len(functions) > 0 + and "tool_calls" in chat_message["delta"] + and chat_message["delta"]["tool_calls"] + and len(chat_message["delta"]["tool_calls"]) > 0 + ): + chat_message["delta"]["function_call"] = chat_message["delta"][ + "tool_calls" + ][0]["function"] + chat_message["delta"]["tool_calls"] = None + + return chat_message diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 0c65789..52fce26 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -41,6 +41,7 @@ from functionary.inference_utils import ( analyze_tools_and_tool_choice, check_all_errors, + convert_tool_calls_to_function_call, create_error_response, ) from functionary.openai_types import ( @@ -83,25 +84,6 @@ class ChatCompletionParams: grammar_sampling: bool -def convert_tool_calls_to_function_call( - functions: Optional[List[Function]], chat_message: Dict -) -> Dict: - if ( - functions - and len(functions) > 0 - and "tool_calls" in chat_message - and chat_message["tool_calls"] is not None - and len(chat_message["tool_calls"]) > 0 - ): - chat_message["function_call"] = { - "name": chat_message["tool_calls"][0]["function"]["name"], - "arguments": chat_message["tool_calls"][0]["function"]["arguments"], - } - chat_message["tool_calls"] = None - - return chat_message - - def v1_chat_generate_request( request: ChatCompletionRequest, tokenizer: AutoTokenizer, @@ -382,19 +364,12 @@ async def completion_stream_generator(params: ChatCompletionParams): params.tools_or_functions, ): # Convert tool_calls to function_call if request.functions is provided - if ( - params.request.functions - and len(params.request.functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = 
response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] - response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ - "function" - ] - response["delta"]["tool_calls"] = None + response = convert_tool_calls_to_function_call( + functions=params.request.functions, chat_message=response + ) + if response["delta"]["function_call"]: + tool_name = response["delta"]["function_call"]["name"] + tool_args = response["delta"]["function_call"]["arguments"] if tool_name and len(tool_name) > 0 and tool_args == "": tool_call_count += 1 diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 9989620..e5e1725 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -14,6 +14,7 @@ from functionary.inference_utils import ( analyze_tools_and_tool_choice, check_all_errors, + convert_tool_calls_to_function_call, create_error_response, ) from functionary.openai_types import ( @@ -193,19 +194,12 @@ async def completion_stream_generator( ): # Convert tool_calls to function_call if request.functions is provided - if ( - functions - and len(functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] - response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ - "function" - ] - response["delta"]["tool_calls"] = None + response = convert_tool_calls_to_function_call( + functions=request.functions, chat_message=response + ) + if response["delta"]["function_call"]: + tool_name = response["delta"]["function_call"]["name"] + tool_args = response["delta"]["function_call"]["arguments"] if tool_name and len(tool_name) > 0 and tool_args == "": tool_call_count += 1 # Return finish_reason after the first tool_call is streamed if functions is provided @@ -277,17 +271,9 @@ async def completion_stream_generator( ) # parse_generated_content(text_response) # Convert tool_calls to function_call if request.functions is provided - if ( - request.functions - and "tool_calls" in chat_mess - and chat_mess["tool_calls"] is not None - and len(chat_mess["tool_calls"]) > 0 - ): - chat_mess["function_call"] = { - "name": chat_mess["tool_calls"][0]["function"]["name"], - "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], - } - chat_mess["tool_calls"] = None + chat_mess = convert_tool_calls_to_function_call( + functions=request.functions, chat_message=chat_mess + ) # Postprocess finish reason if tool_func_choice is None or tool_func_choice in ["auto", "required"]: diff --git a/server_vllm.py b/server_vllm.py index 53f394d..41e390e 100644 --- a/server_vllm.py +++ b/server_vllm.py @@ -27,8 +27,9 @@ import vllm.entrypoints.openai.api_server as vllm_api_server from fastapi import Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import Response from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.api_server import health, mount_metrics +from vllm.entrypoints.openai.api_server import mount_metrics from vllm.entrypoints.openai.protocol import ModelCard, ModelList, ModelPermission from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer @@ -51,7 +52,9 @@ @app.get("/health") async def _health(): """Health check.""" - return await health() + # vLLM's OpenAI server's 
health check is too heavy and also requires
+    # creating an engine_client, so we simply return 200 here.
+    return Response(status_code=200)
 
 
 @app.get("/v1/models")
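
Usage note (not part of the patch): a minimal sketch of how the shared
convert_tool_calls_to_function_call helper behaves for non-streaming and
streaming messages after this refactor. The Function constructor fields and
the message/chunk payloads below are illustrative assumptions, not code taken
from the repository.

    # Illustrative sketch; payloads and Function fields are assumptions.
    from functionary.inference_utils import convert_tool_calls_to_function_call
    from functionary.openai_types import Function

    # The helper only checks that the functions list is non-empty.
    functions = [Function(name="get_weather")]

    # Non-streaming message: the first tool call is promoted to "function_call"
    # and "tool_calls" is set to None.
    message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": '{"city": "Hanoi"}'}}
        ],
    }
    message = convert_tool_calls_to_function_call(
        functions=functions, chat_message=message
    )
    assert message["function_call"]["name"] == "get_weather"
    assert message["tool_calls"] is None

    # Streaming chunk: the same conversion is applied to the "delta" payload.
    chunk = {
        "delta": {"tool_calls": [{"function": {"name": "get_weather", "arguments": ""}}]}
    }
    chunk = convert_tool_calls_to_function_call(functions=functions, chat_message=chunk)
    assert chunk["delta"]["function_call"] == {"name": "get_weather", "arguments": ""}
    assert chunk["delta"]["tool_calls"] is None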