From a5f45697d3394561f6ede7bfbe0010fc564beb14 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Fri, 16 Aug 2024 16:34:07 +0000 Subject: [PATCH 01/40] implement sglang server --- functionary/openai_types.py | 11 + functionary/sglang_inference.py | 525 ++++++++++++++++++++++++++++++++ requirements_sgl.txt | 4 + server_sglang.py | 437 ++++++++++++++++++++++++++ 4 files changed, 977 insertions(+) create mode 100644 functionary/sglang_inference.py create mode 100644 requirements_sgl.txt create mode 100644 server_sglang.py diff --git a/functionary/openai_types.py b/functionary/openai_types.py index abf4e52..9831445 100644 --- a/functionary/openai_types.py +++ b/functionary/openai_types.py @@ -128,12 +128,23 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None + + # Disable logprobs and top_logprobs currently first + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None + # Additional parameters supported by vLLM best_of: Optional[int] = None top_k: Optional[int] = -1 ignore_eos: Optional[bool] = False use_beam_search: Optional[bool] = False + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + regex: Optional[str] = None + min_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + # @validator("tool_choice", always=True) # def validate_tool_choice(cls, value, values): # if value is None: diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py new file mode 100644 index 0000000..0cb6f1f --- /dev/null +++ b/functionary/sglang_inference.py @@ -0,0 +1,525 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Conversion between OpenAI APIs and native SRT APIs""" + +import asyncio +import json +import os +import time +import uuid +from http import HTTPStatus +from typing import Dict, List, Optional + +from fastapi import Request +from fastapi.responses import JSONResponse, StreamingResponse +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.openai_api.protocol import ( + BatchResponse, + ChatCompletionTokenLogprob, + ChoiceLogprobs, + DeltaMessage, + ErrorResponse, + FileResponse, + LogProbs, + TopLogprob, +) + +from functionary.inference_stream import generate_openai_format_from_stream_async +from functionary.inference_utils import analyze_tools_and_tool_choice +from functionary.openai_types import ( + ChatCompletionChunk, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + StreamChoice, + UsageInfo, +) +from functionary.prompt_template import get_prompt_template_from_tokenizer +from functionary.prompt_template.prompt_utils import prepare_messages_for_inference + + +class FileMetadata: + def __init__(self, filename: str, purpose: str): + self.filename = filename + self.purpose = purpose + + +# In-memory storage for batch jobs and files +batch_storage: Dict[str, BatchResponse] = {} +file_id_request: Dict[str, FileMetadata] = {} +file_id_response: Dict[str, FileResponse] = {} +# map file id to file path in SGLang backend +file_id_storage: Dict[str, str] = {} + + +# backend storage directory +storage_dir = None + + +def format_finish_reason(finish_reason) -> Optional[str]: + if finish_reason.startswith("None"): + return None + elif finish_reason.startswith("FINISH_MATCHED"): + return "stop" + elif finish_reason.startswith("FINISH_LENGTH"): + return "length" + elif finish_reason.startswith("FINISH_ABORT"): + return "abort" + else: + return "unknown" + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +): + error = ErrorResponse(message=message, type=err_type, code=status_code.value) + return JSONResponse(content=error.model_dump(), status_code=error.code) + + +def create_streaming_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> str: + error = ErrorResponse(message=message, type=err_type, code=status_code.value) + json_str = json.dumps({"error": error.model_dump()}) + return json_str + + +def v1_chat_generate_request(all_requests, tokenizer_manager): + input_ids = [] + sampling_params_list = [] + image_data_list = [] + return_logprobs = [] + top_logprobs_nums = [] + for request in all_requests: + # Prep the data needed for the underlying GenerateReqInput: + # - prompt: The full prompt string. + # - stop: Custom stop tokens. + # - image_data: None or a list of image strings (URLs or base64 strings). + # None skips any image processing in GenerateReqInput. + tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice( + request=request + ) + if not isinstance(request.messages, str): + # Apply chat template and its stop strings. + prompt_ids = prepare_messages_for_inference( + tokenizer=tokenizer_manager.tokenizer, + messages=request.messages, + tools_or_functions=tools_or_functions, + tool_choice=tool_func_choice, + device="cpu", + ).tolist()[0] + stop = request.stop + image_data = None + else: + # Use the raw prompt and stop strings if the messages is already a string. 
+ prompt_ids = request.messages + stop = request.stop + image_data = None + input_ids.append(prompt_ids) + return_logprobs.append(request.logprobs) + top_logprobs_nums.append(request.top_logprobs) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + } + ) + image_data_list.append(image_data) + if len(all_requests) == 1: + input_ids = input_ids[0] + if isinstance(input_ids, str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + sampling_params_list = sampling_params_list[0] + image_data = image_data_list[0] + return_logprobs = return_logprobs[0] + top_logprobs_nums = top_logprobs_nums[0] + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=image_data, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + top_logprobs_num=top_logprobs_nums, + stream=all_requests[0].stream, + return_text_in_logprobs=True, + ) + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_chat_generate_response(request, prompt_template, ret): + choices = [] + + _, tool_func_choice = analyze_tools_and_tool_choice(request=request) + + for idx, ret_item in enumerate(ret): + logprobs = False + if isinstance(request, list) and request[idx].logprobs: + logprobs = True + elif (not isinstance(request, list)) and request.logprobs: + logprobs = True + if logprobs: + logprobs = to_openai_style_logprobs( + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + token_logprobs = [] + for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + for top_token, top_logprob in logprobs.top_logprobs[0].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + choice_logprobs = ChoiceLogprobs(content=token_logprobs) + else: + choice_logprobs = None + + chat_mess = prompt_template.parse_assistant_response( + llm_output=ret_item["text"], tool_choice=tool_func_choice + ) + finish_reason = False + + # Convert tool_calls to function_call if request.functions is provided + if ( + request.functions + and "tool_calls" in chat_mess + and chat_mess["tool_calls"] is not None + and len(chat_mess["tool_calls"]) > 0 + ): + chat_mess["function_call"] = { + "name": chat_mess["tool_calls"][0]["function"]["name"], + "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], + } + chat_mess["tool_calls"] = None + + # Postprocess finish reason + if "function_call" in chat_mess and chat_mess["function_call"]: + finish_reason = "function_call" + + if "tool_calls" in chat_mess and chat_mess["tool_calls"]: + finish_reason = "tool_calls" + + if 
finish_reason is None: + finish_reason = format_finish_reason(ret_item["meta_info"]["finish_reason"]) + + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage(**chat_mess), + # logprobs=choice_logprobs, + finish_reason=finish_reason, + ) + + choices.append(choice_data) + + prompt_tokens = sum( + ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) + ) + completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) + response = ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +async def v1_chat_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + all_requests = [ChatCompletionRequest(**request_json)] + + prompt_template = get_prompt_template_from_tokenizer( + tokenizer=tokenizer_manager.tokenizer + ) + tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice( + all_requests[0] + ) + + adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) + + if adapted_request.stream: + + async def wrap_sgl_generator(): + stream_buffer = "" + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request + ): + prompt_tokens = content["meta_info"]["prompt_tokens"] + completion_tokens = content["meta_info"]["completion_tokens"] + text = content["text"] + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + finish_reason = format_finish_reason( + content["meta_info"]["finish_reason"] + ) + + # If finish_reason is not None and delta_text is not empty, + # the delta_text is the eos_token and just remove it + if finish_reason is not None and len(delta) > 0: + delta = "" + yield delta, finish_reason + + async def completion_stream_generator(): + generator = wrap_sgl_generator() + + tool_call_count = 0 + async for response in generate_openai_format_from_stream_async( + generator, prompt_template, tool_func_choice, tools_or_functions + ): + # Convert tool_calls to function_call if request.functions is provided + if ( + request.functions + and len(request.functions) > 0 + and "tool_calls" in response["delta"] + and response["delta"]["tool_calls"] + and len(response["delta"]["tool_calls"]) > 0 + ): + tool_name = response["delta"]["tool_calls"][0]["function"]["name"] + tool_args = response["delta"]["tool_calls"][0]["function"][ + "arguments" + ] + response["delta"]["function_call"] = response["delta"][ + "tool_calls" + ][0]["function"] + response["delta"]["tool_calls"] = None + if tool_name and len(tool_name) > 0 and tool_args == "": + tool_call_count += 1 + # Return finish_reason after the first tool_call is streamed if functions is provided + if request.functions and tool_call_count == 2: + response["delta"] = {} + response["finish_reason"] = "function_call" + + chunk = StreamChoice(**response) + result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk]) + chunk_dic = result.dict(exclude_unset=True) + chunk_data = json.dumps(chunk_dic, ensure_ascii=False) + yield f"data: {chunk_data}\n\n" + # Break from for loop after the first tool_call is streamed if functions is provided + if request.functions and tool_call_count == 2: + break + yield "data: [DONE]\n\n" + + # async def generate_stream_resp(): + # is_first = True + + # stream_buffer = "" + # n_prev_token = 0 + # try: + # async for content 
in tokenizer_manager.generate_request( + # adapted_request, raw_request + # ): + # prompt_tokens = content["meta_info"]["prompt_tokens"] + # completion_tokens = content["meta_info"]["completion_tokens"] + # if request.logprobs: + # logprobs = to_openai_style_logprobs( + # output_token_logprobs=content["meta_info"][ + # "output_token_logprobs" + # ][n_prev_token:], + # output_top_logprobs=content["meta_info"][ + # "output_top_logprobs" + # ][n_prev_token:], + # ) + + # n_prev_token = len( + # content["meta_info"]["output_token_logprobs"] + # ) + # token_logprobs = [] + # for token, logprob in zip( + # logprobs.tokens, logprobs.token_logprobs + # ): + # token_bytes = list(token.encode("utf-8")) + # top_logprobs = [] + # if logprobs.top_logprobs: + # for top_token, top_logprob in logprobs.top_logprobs[ + # 0 + # ].items(): + # top_token_bytes = list(top_token.encode("utf-8")) + # top_logprobs.append( + # TopLogprob( + # token=top_token, + # bytes=top_token_bytes, + # logprob=top_logprob, + # ) + # ) + # token_logprobs.append( + # ChatCompletionTokenLogprob( + # token=token, + # bytes=token_bytes, + # logprob=logprob, + # top_logprobs=top_logprobs, + # ) + # ) + + # choice_logprobs = ChoiceLogprobs(content=token_logprobs) + + # else: + # choice_logprobs = None + + # if is_first: + # # First chunk with role + # is_first = False + # choice_data = ChatCompletionResponseStreamChoice( + # index=0, + # delta=DeltaMessage(role="assistant"), + # finish_reason=format_finish_reason( + # content["meta_info"]["finish_reason"] + # ), + # # logprobs=choice_logprobs, + # ) + # chunk = ChatCompletionStreamResponse( + # id=content["meta_info"]["id"], + # choices=[choice_data], + # model=request.model, + # ) + # yield f"data: {chunk.model_dump_json()}\n\n" + + # text = content["text"] + # delta = text[len(stream_buffer) :] + # stream_buffer = stream_buffer + delta + # choice_data = ChatCompletionResponseStreamChoice( + # index=0, + # delta=DeltaMessage(content=delta), + # finish_reason=format_finish_reason( + # content["meta_info"]["finish_reason"] + # ), + # logprobs=choice_logprobs, + # ) + # chunk = ChatCompletionStreamResponse( + # id=content["meta_info"]["id"], + # choices=[choice_data], + # model=request.model, + # ) + # yield f"data: {chunk.model_dump_json()}\n\n" + # if request.stream_options and request.stream_options.include_usage: + # usage = UsageInfo( + # prompt_tokens=prompt_tokens, + # completion_tokens=completion_tokens, + # total_tokens=prompt_tokens + completion_tokens, + # ) + + # final_usage_chunk = ChatCompletionStreamResponse( + # id=str(uuid.uuid4().hex), + # choices=[], + # model=request.model, + # usage=usage, + # ) + # final_usage_data = final_usage_chunk.model_dump_json( + # exclude_unset=True, exclude_none=True + # ) + # yield f"data: {final_usage_data}\n\n" + # except ValueError as e: + # error = create_streaming_error_response(str(e)) + # yield f"data: {error}\n\n" + # yield "data: [DONE]\n\n" + + return StreamingResponse( + # generate_stream_resp(), + completion_stream_generator(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ) + + # Non-streaming response. 
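+    # generate_request returns an async generator even for a single non-streaming
+    # request, so __anext__() pulls the one finished result; the prompt template
+    # then parses the generated text into an OpenAI-compatible chat completion,
+    # including any tool or function calls.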
+ try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + if not isinstance(ret, list): + ret = [ret] + + response = v1_chat_generate_response(request, prompt_template, ret) + + return response + + +def to_openai_style_logprobs( + input_token_logprobs=None, + output_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, +): + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): + for logprob, _, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) + + # Not supported yet + ret_logprobs.text_offset.append(-1) + + def append_top_logprobs(top_logprobs): + for tokens in top_logprobs: + if tokens is not None: + ret_logprobs.top_logprobs.append( + {token[2]: token[0] for token in tokens} + ) + else: + ret_logprobs.top_logprobs.append(None) + + if input_token_logprobs is not None: + append_token_logprobs(input_token_logprobs) + if output_token_logprobs is not None: + append_token_logprobs(output_token_logprobs) + if input_top_logprobs is not None: + append_top_logprobs(input_top_logprobs) + if output_top_logprobs is not None: + append_top_logprobs(output_top_logprobs) + + return ret_logprobs diff --git a/requirements_sgl.txt b/requirements_sgl.txt new file mode 100644 index 0000000..1f131da --- /dev/null +++ b/requirements_sgl.txt @@ -0,0 +1,4 @@ +jsonref~=1.1.0 +sglang[all]==0.2.13 +--find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ +flashinfer==0.1.5 diff --git a/server_sglang.py b/server_sglang.py new file mode 100644 index 0000000..e553206 --- /dev/null +++ b/server_sglang.py @@ -0,0 +1,437 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +The entry point of inference server. +SRT = SGLang Runtime. 
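+
+This server reuses SGLang's HTTP endpoints but routes /v1/chat/completions through
+Functionary's prompt templates (see functionary/sglang_inference.py).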
+""" + +import argparse +import asyncio +import dataclasses +import json +import logging +import multiprocessing as mp +import os +import sys +import threading +import time +from http import HTTPStatus +from typing import Dict, List, Optional, Union + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import requests +import uvicorn +import uvloop +from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi.responses import JSONResponse, Response, StreamingResponse +from sglang.srt.constrained import disable_cache +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.controller_multi import ( + start_controller_process as start_controller_process_multi, +) +from sglang.srt.managers.controller_single import launch_tp_servers +from sglang.srt.managers.controller_single import ( + start_controller_process as start_controller_process_single, +) +from sglang.srt.managers.detokenizer_manager import start_detokenizer_process +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.openai_api.adapter import ( + load_chat_template_for_openai_api, + v1_batches, + v1_delete_file, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, +) +from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import ( + add_api_key_middleware, + allocate_init_ports, + assert_pkg_version, + enable_show_time_cost, + maybe_set_triton_cache_manager, + prepare_model, + prepare_tokenizer, + set_ulimit, +) +from sglang.utils import get_exception_traceback + +from functionary.sglang_inference import v1_chat_completions + +logger = logging.getLogger(__name__) + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +app = FastAPI() +tokenizer_manager = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/get_model_info") +async def get_model_info(): + result = { + "model_path": tokenizer_manager.model_path, + "is_generation": tokenizer_manager.is_generation, + } + return result + + +@app.get("/get_server_args") +async def get_server_args(): + return dataclasses.asdict(tokenizer_manager.server_args) + + +@app.get("/flush_cache") +async def flush_cache(): + tokenizer_manager.flush_cache() + return Response( + content="Cache flushed.\nPlease check backend logs for more details. 
" + "(When there are running or waiting requests, the operation will not be performed.)\n", + status_code=200, + ) + + +async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" + if obj.stream: + + async def stream_results(): + try: + async for out in tokenizer_manager.generate_request(obj, request): + yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} + yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(obj), + ) + else: + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +app.post("/generate")(generate_request) +app.put("/generate")(generate_request) + + +@app.post("/v1/chat/completions") +async def openai_v1_chat_completions(raw_request: Request): + return await v1_chat_completions(tokenizer_manager, raw_request) + + +@app.get("/v1/models") +def available_models(): + """Show available models.""" + served_model_names = [tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, tokenizer_manager.server_args.file_storage_pth + ) + + +@app.delete("/v1/files/{file_id}") +async def delete_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/delete + return await v1_delete_file(file_id) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(tokenizer_manager, raw_request) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + + +def launch_server( + server_args: ServerArgs, + model_overide_args: Optional[dict] = None, + pipe_finish_writer: Optional[mp.connection.Connection] = None, +): + """Launch an HTTP server.""" + global tokenizer_manager + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports + server_args.port, server_args.additional_ports = allocate_init_ports( + server_args.port, + server_args.additional_ports, + server_args.dp_size, + ) + ports = server_args.additional_ports + port_args = PortArgs( + tokenizer_port=ports[0], + controller_port=ports[1], + detokenizer_port=ports[2], + nccl_ports=ports[3:], + ) + logger.info(f"{server_args=}") + + # Use model from www.modelscope.cn, first download the model. 
+ server_args.model_path = prepare_model(server_args.model_path) + server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) + + # Launch processes for multi-node tensor parallelism + if server_args.nnodes > 1: + if server_args.node_rank != 0: + tp_size_local = server_args.tp_size // server_args.nnodes + gpu_ids = [ + i for _ in range(server_args.nnodes) for i in range(tp_size_local) + ] + tp_rank_range = list( + range( + server_args.node_rank * tp_size_local, + (server_args.node_rank + 1) * tp_size_local, + ) + ) + procs = launch_tp_servers( + gpu_ids, + tp_rank_range, + server_args, + ports[3], + model_overide_args, + ) + while True: + pass + + # Launch processes + tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) + pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) + + if server_args.dp_size == 1: + start_process = start_controller_process_single + else: + start_process = start_controller_process_multi + proc_controller = mp.Process( + target=start_process, + args=(server_args, port_args, pipe_controller_writer, model_overide_args), + ) + proc_controller.start() + proc_detoken = mp.Process( + target=start_detokenizer_process, + args=( + server_args, + port_args, + pipe_detoken_writer, + ), + ) + proc_detoken.start() + + # Wait for the model to finish loading + controller_init_state = pipe_controller_reader.recv() + detoken_init_state = pipe_detoken_reader.recv() + + if controller_init_state != "init ok" or detoken_init_state != "init ok": + proc_controller.kill() + proc_detoken.kill() + print( + f"Initialization failed. controller_init_state: {controller_init_state}", + flush=True, + ) + print( + f"Initialization failed. detoken_init_state: {detoken_init_state}", + flush=True, + ) + sys.exit(1) + assert proc_controller.is_alive() and proc_detoken.is_alive() + + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) + + # Send a warmup request + t = threading.Thread( + target=_wait_and_warmup, args=(server_args, pipe_finish_writer) + ) + t.start() + + # Listen for requests + try: + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + finally: + t.join() + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Set ulimit + set_ulimit() + + # Enable show time cost for debugging + if server_args.show_time_cost: + enable_show_time_cost() + + # Disable disk cache + if server_args.disable_disk_cache: + disable_cache() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
+ maybe_set_triton_cache_manager() + + # Check flashinfer version + if not server_args.disable_flashinfer: + assert_pkg_version( + "flashinfer", + "0.1.5", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + +def _wait_and_warmup(server_args, pipe_finish_writer): + headers = {} + url = server_args.url() + if server_args.api_key: + headers["Authorization"] = f"Bearer {server_args.api_key}" + + # Wait until the server is launched + success = False + for _ in range(120): + time.sleep(1) + try: + res = requests.get(url + "/get_model_info", timeout=5, headers=headers) + assert res.status_code == 200, f"{res}" + success = True + break + except (AssertionError, requests.exceptions.RequestException) as e: + last_traceback = get_exception_traceback() + pass + model_info = res.json() + + if not success: + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + print(f"Initialization failed. warmup error: {last_traceback}", flush=True) + sys.exit(1) + + # Send a warmup request + request_name = "/generate" if model_info["is_generation"] else "/encode" + max_new_tokens = 8 if model_info["is_generation"] else 1 + json_data = { + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + } + if server_args.skip_tokenizer_init: + json_data["input_ids"] = [10, 11, 12] + else: + json_data["text"] = "The capital city of France is" + + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + request_name, + json=json_data, + headers=headers, + timeout=600, + ) + assert res.status_code == 200, f"{res}" + except Exception as e: + last_traceback = get_exception_traceback() + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + print(f"Initialization failed. warmup error: {last_traceback}", flush=True) + sys.exit(1) + + # Print warnings here + if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None: + logger.warning( + "You set both `--disable-radix-cache` and `--chunked-prefill-size`. " + "This combination is an experimental feature and we noticed it can lead to " + "wrong generation results. If you want to use chunked prefill, it is recommended " + "not using `--disable-radix-cache`." 
+ ) + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("init ok") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + + launch_server(server_args) From 5d993d9fdacc142f5f6ef76cb187970f51880249 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Sun, 18 Aug 2024 05:29:01 +0000 Subject: [PATCH 02/40] fix --- functionary/sglang_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 0cb6f1f..f157518 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -249,7 +249,7 @@ def v1_chat_generate_response(request, prompt_template, ret): if "tool_calls" in chat_mess and chat_mess["tool_calls"]: finish_reason = "tool_calls" - if finish_reason is None: + if not finish_reason: finish_reason = format_finish_reason(ret_item["meta_info"]["finish_reason"]) choice_data = ChatCompletionResponseChoice( From 432d7b36735c91c954e82461ee58b518082c7c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Musab=20G=C3=BCltekin?= <87330355+musab-mk@users.noreply.github.com> Date: Wed, 21 Aug 2024 01:18:55 +0300 Subject: [PATCH 03/40] Update README.md --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 3748f1e..acafb61 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,15 @@ export VLLM_WORKER_MULTIPROC_METHOD=spawn python server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 ``` +
+ SGLang + +```shell +python server_sglang.py --model-path meetkai/functionary-medium-v3.2 --port 8000 --host 0.0.0.0 --tp 8 +``` + +
+ **Grammar Sampling** From 92a8347260ec0ce8ad8bf849dac5b97db1a05907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Musab=20G=C3=BCltekin?= <87330355+musab-mk@users.noreply.github.com> Date: Wed, 21 Aug 2024 01:31:23 +0300 Subject: [PATCH 04/40] Update sglang_inference.py --- functionary/sglang_inference.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index f157518..b64f38f 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -249,6 +249,7 @@ def v1_chat_generate_response(request, prompt_template, ret): if "tool_calls" in chat_mess and chat_mess["tool_calls"]: finish_reason = "tool_calls" + if not finish_reason: finish_reason = format_finish_reason(ret_item["meta_info"]["finish_reason"]) @@ -342,7 +343,19 @@ async def completion_stream_generator(): if request.functions and tool_call_count == 2: response["delta"] = {} response["finish_reason"] = "function_call" + + # Workaround Fixes + response["delta"]["role"] = "assistant" + if ( + "tool_calls" in response["delta"] + and response["delta"]["tool_calls"] + and len(response["delta"]["tool_calls"]) > 0 + ): + for tool_call in response["delta"]["tool_calls"]: + if tool_call.get("type") is None: + tool_call["type"] = "function" + chunk = StreamChoice(**response) result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk]) chunk_dic = result.dict(exclude_unset=True) From 02dbf0b018f026c0336d18fe86b6dc47f3b0b8a4 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 21 Aug 2024 10:21:24 +0000 Subject: [PATCH 05/40] add <|eom_id|> for v3.1 stop --- functionary/sglang_inference.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index b64f38f..1fd618c 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -127,6 +127,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): device="cpu", ).tolist()[0] stop = request.stop + if ( + get_prompt_template_from_tokenizer( + tokenizer=tokenizer_manager.tokenizer + ).version + == "v3-llama3.1" + ): + stop.append("<|eom_id|>") image_data = None else: # Use the raw prompt and stop strings if the messages is already a string. 
@@ -249,7 +256,6 @@ def v1_chat_generate_response(request, prompt_template, ret): if "tool_calls" in chat_mess and chat_mess["tool_calls"]: finish_reason = "tool_calls" - if not finish_reason: finish_reason = format_finish_reason(ret_item["meta_info"]["finish_reason"]) @@ -343,7 +349,7 @@ async def completion_stream_generator(): if request.functions and tool_call_count == 2: response["delta"] = {} response["finish_reason"] = "function_call" - + # Workaround Fixes response["delta"]["role"] = "assistant" if ( @@ -355,7 +361,6 @@ async def completion_stream_generator(): if tool_call.get("type") is None: tool_call["type"] = "function" - chunk = StreamChoice(**response) result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk]) chunk_dic = result.dict(exclude_unset=True) From 5e643dd0ada1e1c6671ac2bd242ec2d54d4e094e Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 21 Aug 2024 10:53:55 +0000 Subject: [PATCH 06/40] refactor --- functionary/sglang_inference.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 1fd618c..32a27ba 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -126,14 +126,12 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): tool_choice=tool_func_choice, device="cpu", ).tolist()[0] - stop = request.stop - if ( - get_prompt_template_from_tokenizer( + stop = ( + request.stop + + get_prompt_template_from_tokenizer( tokenizer=tokenizer_manager.tokenizer - ).version - == "v3-llama3.1" - ): - stop.append("<|eom_id|>") + ).get_stop_tokens_for_generation() + ) image_data = None else: # Use the raw prompt and stop strings if the messages is already a string. From de331dc7c608a870a147760c22682546f51be3eb Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 22 Aug 2024 08:29:51 +0000 Subject: [PATCH 07/40] add logging to sglang server --- server_sglang.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/server_sglang.py b/server_sglang.py index e553206..b6a7cb1 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -74,6 +74,7 @@ ) from sglang.utils import get_exception_traceback +from functionary.openai_types import ChatCompletionRequest from functionary.sglang_inference import v1_chat_completions logger = logging.getLogger(__name__) @@ -149,6 +150,9 @@ async def stream_results(): @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): + if args.logfile is not None: + request_json = await raw_request.json() + logger.info(ChatCompletionRequest(**request_json).model_dump(mode="json")) return await v1_chat_completions(tokenizer_manager, raw_request) @@ -206,6 +210,8 @@ def launch_server( global tokenizer_manager logging.basicConfig( + filename=args.logfile, + filemode="a", level=getattr(logging, server_args.log_level.upper()), format="%(message)s", ) @@ -430,6 +436,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--logfile", type=str, default=None, help="name of the file to log requests" + ) ServerArgs.add_cli_args(parser) args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) From ce74e994af8b9f6d0362149a7494e0342680a867 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 22 Aug 2024 09:39:42 +0000 Subject: [PATCH 08/40] improve logging --- .../sglang_monkey_patch/tokenizer_manager.py | 107 ++++++++++++++++++ server_sglang.py | 21 ++-- 2 files changed, 120 
insertions(+), 8 deletions(-) create mode 100644 functionary/sglang_monkey_patch/tokenizer_manager.py diff --git a/functionary/sglang_monkey_patch/tokenizer_manager.py b/functionary/sglang_monkey_patch/tokenizer_manager.py new file mode 100644 index 0000000..aa8d6fc --- /dev/null +++ b/functionary/sglang_monkey_patch/tokenizer_manager.py @@ -0,0 +1,107 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""TokenizerManager is a process that tokenizes the text.""" + +import asyncio +import dataclasses +import logging +from typing import Dict, List, Tuple, Union + +import uvloop +from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.server_args import PortArgs, ServerArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +logger = logging.getLogger("tokenizer_logger") + + +@dataclasses.dataclass +class ReqState: + """Store the state a request.""" + + out_list: List + finished: bool + event: asyncio.Event + + +class MonkeyPatchTokenizerManager(TokenizerManager): + """TokenizerManager is a process that tokenizes the text.""" + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args: Dict = None, + logfile: str = "logfile.txt", + ): + super().__init__(server_args, port_args, model_overide_args) + file_handler = logging.FileHandler(logfile) + logger.addHandler(file_handler) + + async def _wait_for_response( + self, + event: asyncio.Event, + state: ReqState, + obj: Union[GenerateReqInput, EmbeddingReqInput], + rid: str, + request, + index: int = None, + response_index: int = 0, + ): + while True: + try: + await asyncio.wait_for(event.wait(), timeout=4) + except asyncio.TimeoutError: + if request is not None and await request.is_disconnected(): + for rid in [obj.rid] if obj.is_single else obj.rid: + self.abort_request(rid) + raise ValueError(f"Abort request {rid}") + continue + + if self.is_generation: + out = self.convert_logprob_style( + state.out_list[-1], + obj.return_logprob if index is None else obj.return_logprob[index], + ( + obj.top_logprobs_num + if index is None + else obj.top_logprobs_num[index] + ), + obj.return_text_in_logprobs, + ) + else: # isinstance(obj, EmbeddingReqInput) + out = state.out_list[-1] + + out["index"] = response_index + + # Log requests + # if self.server_args.log_requests and state.finished: + if state.finished: + if obj.text is None and obj.input_ids is not None: + obj.text = self.tokenizer.decode(obj.input_ids) + obj.input_ids = None + logger.info(dict(input=obj.__dict__, output=out)) + + state.out_list = [] + if state.finished: + del self.rid_to_state[rid] + yield out + break + + event.clear() + yield out diff --git a/server_sglang.py b/server_sglang.py index b6a7cb1..c41e8fb 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -74,8 +74,10 @@ ) from sglang.utils import get_exception_traceback -from functionary.openai_types import ChatCompletionRequest from 
functionary.sglang_inference import v1_chat_completions +from functionary.sglang_monkey_patch.tokenizer_manager import ( + MonkeyPatchTokenizerManager, +) logger = logging.getLogger(__name__) @@ -150,9 +152,6 @@ async def stream_results(): @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - if args.logfile is not None: - request_json = await raw_request.json() - logger.info(ChatCompletionRequest(**request_json).model_dump(mode="json")) return await v1_chat_completions(tokenizer_manager, raw_request) @@ -210,8 +209,6 @@ def launch_server( global tokenizer_manager logging.basicConfig( - filename=args.logfile, - filemode="a", level=getattr(logging, server_args.log_level.upper()), format="%(message)s", ) @@ -262,7 +259,12 @@ def launch_server( pass # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) + if args.logfile is not None: + tokenizer_manager = MonkeyPatchTokenizerManager( + server_args, port_args, model_overide_args, logfile=args.logfile + ) + else: + tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) @@ -437,7 +439,10 @@ def _wait_and_warmup(server_args, pipe_finish_writer): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--logfile", type=str, default=None, help="name of the file to log requests" + "--logfile", + type=str, + default=None, + help="enable detailed request input/output logging by providing logfile", ) ServerArgs.add_cli_args(parser) args = parser.parse_args() From 889566421c98079486e742a4aa7e11af44d9eb8f Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 22 Aug 2024 18:14:04 +0800 Subject: [PATCH 09/40] switch to RotatingFileHandler --- functionary/sglang_monkey_patch/tokenizer_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/functionary/sglang_monkey_patch/tokenizer_manager.py b/functionary/sglang_monkey_patch/tokenizer_manager.py index aa8d6fc..4d4a1eb 100644 --- a/functionary/sglang_monkey_patch/tokenizer_manager.py +++ b/functionary/sglang_monkey_patch/tokenizer_manager.py @@ -50,7 +50,9 @@ def __init__( logfile: str = "logfile.txt", ): super().__init__(server_args, port_args, model_overide_args) - file_handler = logging.FileHandler(logfile) + file_handler = logging.handlers.RotatingFileHandler( + logfile, maxBytes=1024 * 1024 * 100, backupCount=10 + ) logger.addHandler(file_handler) async def _wait_for_response( From f3bc1134ac1508bb7f20a6dfe883ece9c3816c07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Musab=20G=C3=BCltekin?= <87330355+musab-mk@users.noreply.github.com> Date: Fri, 23 Aug 2024 00:24:57 +0300 Subject: [PATCH 10/40] Update sglang_inference.py --- functionary/sglang_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 32a27ba..1691ec2 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -285,6 +285,10 @@ def v1_chat_generate_response(request, prompt_template, ret): async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() + for message in request_json["messages"]: + if message["role"] == "assistant" and message["content"] == "": + message["content"] = None + print(request_json) all_requests = 
[ChatCompletionRequest(**request_json)] prompt_template = get_prompt_template_from_tokenizer( From a17881ea20e907e894ffdfec26a5426e8d523f6d Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 12 Sep 2024 03:54:36 +0000 Subject: [PATCH 11/40] fix ChatCompletionChunk --- functionary/sglang_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 32a27ba..9e98f1c 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -360,7 +360,7 @@ async def completion_stream_generator(): tool_call["type"] = "function" chunk = StreamChoice(**response) - result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk]) + result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk], model=request.model) chunk_dic = result.dict(exclude_unset=True) chunk_data = json.dumps(chunk_dic, ensure_ascii=False) yield f"data: {chunk_data}\n\n" From 4c36143b13af36213513e9caa9d4349c89d08ced Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Mon, 30 Sep 2024 23:23:34 +0000 Subject: [PATCH 12/40] update server script --- server_sglang.py | 343 ++++++++++++++--------------------------------- 1 file changed, 98 insertions(+), 245 deletions(-) diff --git a/server_sglang.py b/server_sglang.py index c41e8fb..a380699 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -1,3 +1,20 @@ +# Adapted from +# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server.py + +# Copyright 2023-2024 SGLang Team + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -37,42 +54,23 @@ import requests import uvicorn import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from sglang.srt.constrained import disable_cache -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller_multi import ( - start_controller_process as start_controller_process_multi, -) -from sglang.srt.managers.controller_single import launch_tp_servers -from sglang.srt.managers.controller_single import ( - start_controller_process as start_controller_process_single, -) -from sglang.srt.managers.detokenizer_manager import start_detokenizer_process +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_delete_file, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) +from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server import _set_envs_and_config, _wait_and_warmup from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( add_api_key_middleware, allocate_init_ports, - assert_pkg_version, - enable_show_time_cost, - maybe_set_triton_cache_manager, - prepare_model, - prepare_tokenizer, - set_ulimit, + configure_logger, + prepare_model_and_tokenizer, ) -from sglang.utils import get_exception_traceback from functionary.sglang_inference import v1_chat_completions from functionary.sglang_monkey_patch.tokenizer_manager import ( @@ -87,13 +85,36 @@ app = FastAPI() tokenizer_manager = None +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + @app.get("/health") async def health() -> Response: - """Health check.""" + """Check the health of the http server.""" return Response(status_code=200) +@app.get("/health_generate") +async def health_generate(request: Request) -> Response: + """Check the health of the inference server by generating one token.""" + gri = GenerateReqInput( + text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7} + ) + try: + async for _ in tokenizer_manager.generate_request(gri, request): + break + return Response(status_code=200) + except Exception as e: + logger.exception(e) + return Response(status_code=503) + + @app.get("/get_model_info") async def get_model_info(): result = { @@ -165,58 +186,19 @@ def available_models(): return ModelList(data=model_cards) -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, 
raw_request) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - def launch_server( server_args: ServerArgs, - model_overide_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" global tokenizer_manager - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) - + # Configure global environment + configure_logger(server_args) server_args.check_server_args() _set_envs_and_config(server_args) - # Allocate ports + # Allocate ports for inter-process communications server_args.port, server_args.additional_ports = allocate_init_ports( server_args.port, server_args.additional_ports, @@ -225,87 +207,63 @@ def launch_server( ports = server_args.additional_ports port_args = PortArgs( tokenizer_port=ports[0], - controller_port=ports[1], + scheduler_port=ports[1], detokenizer_port=ports[2], nccl_ports=ports[3:], ) logger.info(f"{server_args=}") - # Use model from www.modelscope.cn, first download the model. - server_args.model_path = prepare_model(server_args.model_path) - server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) - - # Launch processes for multi-node tensor parallelism - if server_args.nnodes > 1: - if server_args.node_rank != 0: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [ - i for _ in range(server_args.nnodes) for i in range(tp_size_local) - ] - tp_rank_range = list( - range( - server_args.node_rank * tp_size_local, - (server_args.node_rank + 1) * tp_size_local, - ) - ) - procs = launch_tp_servers( - gpu_ids, - tp_rank_range, - server_args, - ports[3], - model_overide_args, - ) - while True: - pass + # If using model from www.modelscope.cn, first download the model. 
+ server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) - # Launch processes - if args.logfile is not None: - tokenizer_manager = MonkeyPatchTokenizerManager( - server_args, port_args, model_overide_args, logfile=args.logfile + # Launch tensor parallel scheduler processes + scheduler_procs = [] + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, writer), ) - else: - tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) - pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + + if server_args.node_rank >= 1: + # For other nodes, they do not need to run tokenizer or detokenizer, + # so they can just wait here. + while True: + pass - if server_args.dp_size == 1: - start_process = start_controller_process_single - else: - start_process = start_controller_process_multi - proc_controller = mp.Process( - target=start_process, - args=(server_args, port_args, pipe_controller_writer, model_overide_args), - ) - proc_controller.start() - proc_detoken = mp.Process( - target=start_detokenizer_process, + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, args=( server_args, port_args, - pipe_detoken_writer, ), ) - proc_detoken.start() - - # Wait for the model to finish loading - controller_init_state = pipe_controller_reader.recv() - detoken_init_state = pipe_detoken_reader.recv() - - if controller_init_state != "init ok" or detoken_init_state != "init ok": - proc_controller.kill() - proc_detoken.kill() - print( - f"Initialization failed. controller_init_state: {controller_init_state}", - flush=True, - ) - print( - f"Initialization failed. 
detoken_init_state: {detoken_init_state}", - flush=True, - ) - sys.exit(1) - assert proc_controller.is_alive() and proc_detoken.is_alive() + detoken_proc.start() + + # Launch tokenizer process + if args.logfile is not None: + tokenizer_manager = MonkeyPatchTokenizerManager(server_args, port_args) + else: + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + + # Wait for model to finish loading + for i in range(len(scheduler_pipe_readers)): + scheduler_pipe_readers[i].recv() # Add api key authorization if server_args.api_key: @@ -313,12 +271,12 @@ def launch_server( # Send a warmup request t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer) + target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid()) ) t.start() - # Listen for requests try: + # Listen for HTTP requests uvicorn.run( app, host=server_args.host, @@ -331,111 +289,6 @@ def launch_server( t.join() -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - - # Set ulimit - set_ulimit() - - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if not server_args.disable_flashinfer: - assert_pkg_version( - "flashinfer", - "0.1.5", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - -def _wait_and_warmup(server_args, pipe_finish_writer): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res}" - success = True - break - except (AssertionError, requests.exceptions.RequestException) as e: - last_traceback = get_exception_traceback() - pass - model_info = res.json() - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. 
warmup error: {last_traceback}", flush=True) - sys.exit(1) - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - max_new_tokens = 8 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [10, 11, 12] - else: - json_data["text"] = "The capital city of France is" - - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=600, - ) - assert res.status_code == 200, f"{res}" - except Exception as e: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. warmup error: {last_traceback}", flush=True) - sys.exit(1) - - # Print warnings here - if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None: - logger.warning( - "You set both `--disable-radix-cache` and `--chunked-prefill-size`. " - "This combination is an experimental feature and we noticed it can lead to " - "wrong generation results. If you want to use chunked prefill, it is recommended " - "not using `--disable-radix-cache`." - ) - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("init ok") - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( From 6c54bc43da1533a0e36281b44db53491cfefc21f Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Mon, 30 Sep 2024 23:42:17 +0000 Subject: [PATCH 13/40] update logging monkey patch --- .../sglang_monkey_patch/tokenizer_manager.py | 37 ++++++++----------- server_sglang.py | 2 +- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/functionary/sglang_monkey_patch/tokenizer_manager.py b/functionary/sglang_monkey_patch/tokenizer_manager.py index 4d4a1eb..25f06a1 100644 --- a/functionary/sglang_monkey_patch/tokenizer_manager.py +++ b/functionary/sglang_monkey_patch/tokenizer_manager.py @@ -16,13 +16,17 @@ """TokenizerManager is a process that tokenizes the text.""" import asyncio -import dataclasses import logging -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union +import fastapi import uvloop -from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput -from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + RewardReqInput, +) +from sglang.srt.managers.tokenizer_manager import ReqState, TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -30,15 +34,6 @@ logger = logging.getLogger("tokenizer_logger") -@dataclasses.dataclass -class ReqState: - """Store the state a request.""" - - out_list: List - finished: bool - event: asyncio.Event - - class MonkeyPatchTokenizerManager(TokenizerManager): """TokenizerManager is a process that tokenizes the text.""" @@ -46,10 +41,9 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_overide_args: Dict = None, logfile: str = "logfile.txt", ): - super().__init__(server_args, port_args, model_overide_args) + super().__init__(server_args, port_args) file_handler = logging.handlers.RotatingFileHandler( logfile, maxBytes=1024 * 1024 * 100, backupCount=10 ) @@ -57,17 +51,16 @@ def __init__( 
async def _wait_for_response( self, - event: asyncio.Event, state: ReqState, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], rid: str, - request, - index: int = None, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, response_index: int = 0, ): while True: try: - await asyncio.wait_for(event.wait(), timeout=4) + await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): for rid in [obj.rid] if obj.is_single else obj.rid: @@ -86,7 +79,7 @@ async def _wait_for_response( ), obj.return_text_in_logprobs, ) - else: # isinstance(obj, EmbeddingReqInput) + else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput)) out = state.out_list[-1] out["index"] = response_index @@ -105,5 +98,5 @@ async def _wait_for_response( yield out break - event.clear() + state.event.clear() yield out diff --git a/server_sglang.py b/server_sglang.py index a380699..2ff696d 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -255,7 +255,7 @@ def launch_server( # Launch tokenizer process if args.logfile is not None: - tokenizer_manager = MonkeyPatchTokenizerManager(server_args, port_args) + tokenizer_manager = MonkeyPatchTokenizerManager(server_args, port_args, args.logfile) else: tokenizer_manager = TokenizerManager(server_args, port_args) if server_args.chat_template: From 43f9c99b09f0a4523cac648fd15325b906597033 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 1 Oct 2024 01:48:04 +0000 Subject: [PATCH 14/40] set up frontend lang runtime in server_sglang --- server_sglang.py | 110 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 4 deletions(-) diff --git a/server_sglang.py b/server_sglang.py index 2ff696d..b5d3a8a 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -42,6 +42,7 @@ import logging import multiprocessing as mp import os +import socket import sys import threading import time @@ -51,19 +52,20 @@ # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) -import requests +import sglang as sgl import uvicorn import uvloop from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server import _set_envs_and_config, _wait_and_warmup +from sglang.srt.server import Runtime, _set_envs_and_config, _wait_and_warmup from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( add_api_key_middleware, @@ -255,7 +257,9 @@ def launch_server( # Launch tokenizer process if args.logfile is not None: - tokenizer_manager = MonkeyPatchTokenizerManager(server_args, port_args, args.logfile) + tokenizer_manager = MonkeyPatchTokenizerManager( + server_args, port_args, args.logfile + ) else: tokenizer_manager = TokenizerManager(server_args, port_args) if server_args.chat_template: @@ -289,6 +293,79 @@ def launch_server( t.join() +def find_free_port(exclude_port: 
int) -> int: + """ + This function finds a free port that is not the excluded port. + + Args: + exclude_port (int): The port number to exclude from selection. + + Returns: + int: A free port number that is not the excluded port. + """ + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + if s.getsockname()[1] != exclude_port: + return s.getsockname()[1] + except socket.error: + continue + + +class FunctionaryRuntime(Runtime): + """ + A wrapper for the server. + This is used for launching the server in a python program without + using the commond line interface. + """ + + def __init__( + self, + log_level: str = "error", + *args, + **kwargs, + ): + """See the arguments in server_args.py::ServerArgs""" + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + + # Pre-allocate ports + self.server_args.port, self.server_args.additional_ports = allocate_init_ports( + self.server_args.port, + self.server_args.additional_ports, + self.server_args.dp_size, + ) + + self.url = self.server_args.url() + self.generate_url = ( + f"http://{self.server_args.host}:{self.server_args.port}/generate" + ) + + self.pid = None + pipe_reader, pipe_writer = mp.Pipe(duplex=False) + + proc = mp.Process( + target=launch_server, + args=(self.server_args, pipe_writer), + ) + proc.start() + pipe_writer.close() + self.pid = proc.pid + + try: + init_state = pipe_reader.recv() + except EOFError: + init_state = "" + + if init_state != "ready": + self.shutdown() + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + + self.endpoint = RuntimeEndpoint(self.url) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -297,8 +374,33 @@ def launch_server( default=None, help="enable detailed request input/output logging by providing logfile", ) + parser.add_argument( + "--enable-grammar-sampling", + dest="grammar_sampling", + action="store_true", + default=False, + help="enable grammar sampling for function names", + ) ServerArgs.add_cli_args(parser) args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args) + if args.grammar_sampling: + wrapper_port = server_args.port + # Find a new random free port for the backend server runtime + server_args.port = find_free_port(exclude_port=wrapper_port) + backend = FunctionaryRuntime(**vars(server_args)) + sgl.set_default_backend( + sgl.RuntimeEndpoint(f"http://{server_args.host}:{server_args.port}") + ) + uvicorn.run( + app, + host=server_args.host, + port=wrapper_port, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + backend.shutdown() + else: + launch_server(server_args) From 03f737ca1d414c7b491ab04d04b8966aaa4826ec Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 1 Oct 2024 12:44:02 +0000 Subject: [PATCH 15/40] wip --- functionary/prompt_template/prompt_utils.py | 6 + functionary/sglang_inference.py | 246 ++++++++++---------- server_sglang.py | 10 +- 3 files changed, 141 insertions(+), 121 deletions(-) diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py index 600ff98..639d07a 100644 --- a/functionary/prompt_template/prompt_utils.py +++ b/functionary/prompt_template/prompt_utils.py @@ -59,6 +59,7 @@ def prepare_messages_for_inference( messages: List[ChatMessage], tools_or_functions: List[Dict], tool_choice: Optional[Union[str, Tool, Function]] = None, + return_text: bool = False, device="cuda:0", 
) -> torch.Tensor: """This function receives the messages and generates the final prompt tokenized by the @@ -69,6 +70,7 @@ def prepare_messages_for_inference( messages (List[ChatMessage]): The list of messages for the conversation tools_or_functions (List[Dict]): list of tools or functions tool_choice (Optional[Union[str, Tool, Function]], optional): tool_choice provided by the user. Defaults to None. + return_text (bool, optional): whether to return the text of the prompt. Defaults to False. device (str, optional): device for the tokenized tensor. Defaults to "cuda:0". Returns: @@ -95,6 +97,10 @@ def prepare_messages_for_inference( # add prefix based on tool-choice final_prompt += prompt_template.get_generation_prefix_for_tool_choice(tool_choice) + + if return_text: + return final_prompt + input_ids = tokenizer(final_prompt, return_tensors="pt").input_ids input_ids = input_ids.to(device) return input_ids diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 9e98f1c..ab29ff2 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -21,10 +21,14 @@ import time import uuid from http import HTTPStatus -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional +import sglang as sgl from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse +from outlines.fsm.json_schema import build_regex_from_schema +from sglang.lang.choices import greedy_token_selection +from sglang.lang.interpreter import ProgramState from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchResponse, @@ -70,6 +74,10 @@ def __init__(self, filename: str, purpose: str): storage_dir = None +# Choices sampling method for sgl.select +CHOICES_SAMPLING_METHOD = greedy_token_selection + + def format_finish_reason(finish_reason) -> Optional[str]: if finish_reason.startswith("None"): return None @@ -102,7 +110,7 @@ def create_streaming_error_response( return json_str -def v1_chat_generate_request(all_requests, tokenizer_manager): +def v1_chat_generate_request(all_requests, tokenizer): input_ids = [] sampling_params_list = [] image_data_list = [] @@ -120,7 +128,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): if not isinstance(request.messages, str): # Apply chat template and its stop strings. 
prompt_ids = prepare_messages_for_inference( - tokenizer=tokenizer_manager.tokenizer, + tokenizer=tokenizer, messages=request.messages, tools_or_functions=tools_or_functions, tool_choice=tool_func_choice, @@ -129,7 +137,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): stop = ( request.stop + get_prompt_template_from_tokenizer( - tokenizer=tokenizer_manager.tokenizer + tokenizer=tokenizer ).get_stop_tokens_for_generation() ) image_data = None @@ -286,6 +294,7 @@ def v1_chat_generate_response(request, prompt_template, ret): async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() all_requests = [ChatCompletionRequest(**request_json)] + tokenizer = tokenizer_manager.tokenizer prompt_template = get_prompt_template_from_tokenizer( tokenizer=tokenizer_manager.tokenizer @@ -294,7 +303,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): all_requests[0] ) - adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) + adapted_request, request = v1_chat_generate_request(all_requests, tokenizer) if adapted_request.stream: @@ -360,7 +369,9 @@ async def completion_stream_generator(): tool_call["type"] = "function" chunk = StreamChoice(**response) - result = ChatCompletionChunk(id=adapted_request.rid, choices=[chunk], model=request.model) + result = ChatCompletionChunk( + id=adapted_request.rid, choices=[chunk], model=request.model + ) chunk_dic = result.dict(exclude_unset=True) chunk_data = json.dumps(chunk_dic, ensure_ascii=False) yield f"data: {chunk_data}\n\n" @@ -369,119 +380,6 @@ async def completion_stream_generator(): break yield "data: [DONE]\n\n" - # async def generate_stream_resp(): - # is_first = True - - # stream_buffer = "" - # n_prev_token = 0 - # try: - # async for content in tokenizer_manager.generate_request( - # adapted_request, raw_request - # ): - # prompt_tokens = content["meta_info"]["prompt_tokens"] - # completion_tokens = content["meta_info"]["completion_tokens"] - # if request.logprobs: - # logprobs = to_openai_style_logprobs( - # output_token_logprobs=content["meta_info"][ - # "output_token_logprobs" - # ][n_prev_token:], - # output_top_logprobs=content["meta_info"][ - # "output_top_logprobs" - # ][n_prev_token:], - # ) - - # n_prev_token = len( - # content["meta_info"]["output_token_logprobs"] - # ) - # token_logprobs = [] - # for token, logprob in zip( - # logprobs.tokens, logprobs.token_logprobs - # ): - # token_bytes = list(token.encode("utf-8")) - # top_logprobs = [] - # if logprobs.top_logprobs: - # for top_token, top_logprob in logprobs.top_logprobs[ - # 0 - # ].items(): - # top_token_bytes = list(top_token.encode("utf-8")) - # top_logprobs.append( - # TopLogprob( - # token=top_token, - # bytes=top_token_bytes, - # logprob=top_logprob, - # ) - # ) - # token_logprobs.append( - # ChatCompletionTokenLogprob( - # token=token, - # bytes=token_bytes, - # logprob=logprob, - # top_logprobs=top_logprobs, - # ) - # ) - - # choice_logprobs = ChoiceLogprobs(content=token_logprobs) - - # else: - # choice_logprobs = None - - # if is_first: - # # First chunk with role - # is_first = False - # choice_data = ChatCompletionResponseStreamChoice( - # index=0, - # delta=DeltaMessage(role="assistant"), - # finish_reason=format_finish_reason( - # content["meta_info"]["finish_reason"] - # ), - # # logprobs=choice_logprobs, - # ) - # chunk = ChatCompletionStreamResponse( - # id=content["meta_info"]["id"], - # choices=[choice_data], - # model=request.model, - # ) - # yield 
f"data: {chunk.model_dump_json()}\n\n" - - # text = content["text"] - # delta = text[len(stream_buffer) :] - # stream_buffer = stream_buffer + delta - # choice_data = ChatCompletionResponseStreamChoice( - # index=0, - # delta=DeltaMessage(content=delta), - # finish_reason=format_finish_reason( - # content["meta_info"]["finish_reason"] - # ), - # logprobs=choice_logprobs, - # ) - # chunk = ChatCompletionStreamResponse( - # id=content["meta_info"]["id"], - # choices=[choice_data], - # model=request.model, - # ) - # yield f"data: {chunk.model_dump_json()}\n\n" - # if request.stream_options and request.stream_options.include_usage: - # usage = UsageInfo( - # prompt_tokens=prompt_tokens, - # completion_tokens=completion_tokens, - # total_tokens=prompt_tokens + completion_tokens, - # ) - - # final_usage_chunk = ChatCompletionStreamResponse( - # id=str(uuid.uuid4().hex), - # choices=[], - # model=request.model, - # usage=usage, - # ) - # final_usage_data = final_usage_chunk.model_dump_json( - # exclude_unset=True, exclude_none=True - # ) - # yield f"data: {final_usage_data}\n\n" - # except ValueError as e: - # error = create_streaming_error_response(str(e)) - # yield f"data: {error}\n\n" - # yield "data: [DONE]\n\n" - return StreamingResponse( # generate_stream_resp(), completion_stream_generator(), @@ -504,6 +402,116 @@ async def completion_stream_generator(): return response +async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): + request_json = await raw_request.json() + request = ChatCompletionRequest(**request_json) + tokenizer = backend.get_tokenizer() + + prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) + tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) + + gen_state = prompt_template.initialize_fsm_gen_state( + tool_choice=tool_func_choice, + curr_text="", + curr_tokens=None, + add_code_interpreter=( + True + if any( + [ + "type" in tool_or_func + and tool_or_func["type"] == "code_interpreter" + for tool_or_func in tools_or_functions + ] + ) + else False + ), + ) + + @sgl.function + def generate_response(s: ProgramState, gen_state: Dict): + s += prepare_messages_for_inference( + tokenizer=tokenizer, + messages=request.messages, + tools_or_functions=tools_or_functions, + tool_choice=tool_func_choice, + return_text=True, + ) + + # Form the options for the following stages + tools = [] + for tool in tools_or_functions: + if "type" in tool: + if tool["type"] == "function": + tools.append(tool["function"]) + else: + tools.append(tool) + options = prompt_template.get_options_from_gen_state( + gen_state=gen_state, tools_or_functions=tools + ) + + recipient_idx = 0 + recipient_var = f"recipient_{recipient_idx}" + content_var = f"content_{recipient_idx}" + while True: + if gen_state["stage"] == "function": + choices = [ + tool["function"]["name"] if "function" in tool else tool["name"] + for tool in tools_or_functions + ] + if gen_state["add_all_recipient"]: + choices.append("all") + s += sgl.select( + name=recipient_var, + choices=choices, + choices_method=CHOICES_SAMPLING_METHOD, + ) + new_token = s[recipient_var] + elif gen_state["stage"] == "pre-parameter": + s += prompt_template.fn_param_sep_token + new_token = prompt_template.fn_param_sep_token + elif gen_state["stage"] == "parameter": + tool = next(t for t in tools if t["name"] == gen_state["func_name"]) + regex = build_regex_from_schema(json.dumps(tool["parameters"])) + s += sgl.gen(name=content_var, regex=regex) + elif gen_state["stage"] == "text-gen": + s += 
sgl.gen( + name=content_var, + stop=[prompt_template.get_start_of_function_call_token()] + + prompt_template.get_stop_tokens_for_generation(), + ) + elif gen_state["stage"] == "pre-function": + s += sgl.gen( + name=content_var, + stop=[prompt_template.get_start_of_function_call_token()] + + prompt_template.get_stop_tokens_for_generation(), + ) + + if content_var in s: + stop_match = s.get_meta_info(content_var)["finish_reason"]["matched"] + if not isinstance(stop_match, str): + stop_match = tokenizer.decode(stop_match) + if stop_match in prompt_template.get_stop_tokens_for_generation(): + break + else: + gen_state["stage"] = "pre-function" + gen_state["curr_text"] = ( + prompt_template.get_start_of_function_call_token() + ) + new_token = prompt_template.get_start_of_function_call_token() + + gen_state = prompt_template.update_fsm_gen_state( + gen_state=gen_state, + new_token=new_token, + new_token_id=None, + options=options, + tokenizer=tokenizer, + ) + + state = generate_response.run(gen_state=gen_state) + breakpoint() + # text_response = + + def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, diff --git a/server_sglang.py b/server_sglang.py index b5d3a8a..5ae7d83 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -74,7 +74,10 @@ prepare_model_and_tokenizer, ) -from functionary.sglang_inference import v1_chat_completions +from functionary.sglang_inference import ( + v1_chat_completions, + v1_chat_completions_grammar_sampling, +) from functionary.sglang_monkey_patch.tokenizer_manager import ( MonkeyPatchTokenizerManager, ) @@ -175,7 +178,10 @@ async def stream_results(): @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) + if args.grammar_sampling: + return await v1_chat_completions_grammar_sampling(backend, raw_request) + else: + return await v1_chat_completions(tokenizer_manager, raw_request) @app.get("/v1/models") From 2d94725f742cd82d4bb5e9bb71fe9ad1ed0c00de Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 2 Oct 2024 13:56:15 +0000 Subject: [PATCH 16/40] non-streaming --- functionary/sglang_inference.py | 162 ++++++++++++++++++++++++-------- server_sglang.py | 6 +- 2 files changed, 128 insertions(+), 40 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index ab29ff2..f28d3d0 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -18,6 +18,7 @@ import asyncio import json import os +import re import time import uuid from http import HTTPStatus @@ -49,7 +50,9 @@ ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, + Function, StreamChoice, + Tool, UsageInfo, ) from functionary.prompt_template import get_prompt_template_from_tokenizer @@ -407,6 +410,22 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): request = ChatCompletionRequest(**request_json) tokenizer = backend.get_tokenizer() + # Convert legacy functions to tools + if request.functions is not None: + request.tools = [ + Tool(type="function", function=function) for function in request.functions + ] + # Convert legacy function_call to tool_choice + if request.function_call is not None: + if isinstance(request.function_call, str) and ( + request.function_call == "none" or request.function_call == "auto" + ): + request.tool_choice = request.function_call + if request.function_call and isinstance(request.function_call, Function): + request.tool_choice = Tool( + 
type="function", function=Function(name=request.function_call.name) + ) + prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) @@ -426,16 +445,20 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): else False ), ) + prompt = prepare_messages_for_inference( + tokenizer=tokenizer, + messages=request.messages, + tools_or_functions=tools_or_functions, + tool_choice=tool_func_choice, + return_text=True, + ) + + recipient_var = "recipient" + content_var = "content" @sgl.function def generate_response(s: ProgramState, gen_state: Dict): - s += prepare_messages_for_inference( - tokenizer=tokenizer, - messages=request.messages, - tools_or_functions=tools_or_functions, - tool_choice=tool_func_choice, - return_text=True, - ) + s += prompt # Form the options for the following stages tools = [] @@ -449,17 +472,26 @@ def generate_response(s: ProgramState, gen_state: Dict): gen_state=gen_state, tools_or_functions=tools ) - recipient_idx = 0 - recipient_var = f"recipient_{recipient_idx}" - content_var = f"content_{recipient_idx}" + stop_tokens = prompt_template.get_stop_tokens_for_generation() + function_call_token = prompt_template.get_start_of_function_call_token() + + def check_stop_condition(): + stop_match = s.get_meta_info(content_var)["finish_reason"]["matched"] + if not isinstance(stop_match, str): + stop_match = tokenizer.decode(stop_match) + return stop_match in stop_tokens + while True: if gen_state["stage"] == "function": choices = [ - tool["function"]["name"] if "function" in tool else tool["name"] + tool["function"]["name"] for tool in tools_or_functions + if tool["type"] == "function" ] if gen_state["add_all_recipient"]: choices.append("all") + if gen_state["add_code_interpreter"]: + choices.append("python") s += sgl.select( name=recipient_var, choices=choices, @@ -471,33 +503,31 @@ def generate_response(s: ProgramState, gen_state: Dict): new_token = prompt_template.fn_param_sep_token elif gen_state["stage"] == "parameter": tool = next(t for t in tools if t["name"] == gen_state["func_name"]) - regex = build_regex_from_schema(json.dumps(tool["parameters"])) - s += sgl.gen(name=content_var, regex=regex) - elif gen_state["stage"] == "text-gen": - s += sgl.gen( - name=content_var, - stop=[prompt_template.get_start_of_function_call_token()] - + prompt_template.get_stop_tokens_for_generation(), + regex = ( + build_regex_from_schema(json.dumps(tool["parameters"])) + + f"({re.escape(function_call_token)})?" 
) - elif gen_state["stage"] == "pre-function": - s += sgl.gen( - name=content_var, - stop=[prompt_template.get_start_of_function_call_token()] - + prompt_template.get_stop_tokens_for_generation(), - ) - - if content_var in s: - stop_match = s.get_meta_info(content_var)["finish_reason"]["matched"] - if not isinstance(stop_match, str): - stop_match = tokenizer.decode(stop_match) - if stop_match in prompt_template.get_stop_tokens_for_generation(): + s += sgl.gen(name=content_var, regex=regex, stop=function_call_token) + new_token = s[content_var] + if check_stop_condition(): + break + elif gen_state["stage"] == "text-gen": + s += sgl.gen(name=content_var, stop=function_call_token) + if check_stop_condition(): break else: - gen_state["stage"] = "pre-function" - gen_state["curr_text"] = ( - prompt_template.get_start_of_function_call_token() - ) - new_token = prompt_template.get_start_of_function_call_token() + s += function_call_token + new_token = s[content_var] + function_call_token + elif gen_state["stage"] == "code-interpreter": + s += sgl.gen(name=content_var, stop=function_call_token) + if check_stop_condition(): + break + else: + s += function_call_token + new_token = s[content_var] + function_call_token + elif gen_state["stage"] == "pre-function": + s += function_call_token + new_token = function_call_token gen_state = prompt_template.update_fsm_gen_state( gen_state=gen_state, @@ -507,9 +537,65 @@ def generate_response(s: ProgramState, gen_state: Dict): tokenizer=tokenizer, ) - state = generate_response.run(gen_state=gen_state) - breakpoint() - # text_response = + state = generate_response.run( + gen_state=gen_state, + max_new_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + stream=request.stream, + ) + + chat_mess = prompt_template.parse_assistant_response( + llm_output=state.text()[len(prompt) :], tool_choice=tool_func_choice + ) + + # Convert tool_calls to function_call if request.functions is provided + if ( + request.functions + and "tool_calls" in chat_mess + and chat_mess["tool_calls"] is not None + and len(chat_mess["tool_calls"]) > 0 + ): + chat_mess["function_call"] = { + "name": chat_mess["tool_calls"][0]["function"]["name"], + "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], + } + chat_mess["tool_calls"] = None + + # Postprocess finish reason + finish_reason = "stop" + if "function_call" in chat_mess and chat_mess["function_call"]: + finish_reason = "function_call" + if "tool_calls" in chat_mess and chat_mess["tool_calls"]: + finish_reason = "tool_calls" + + choices = [ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(**chat_mess), + # logprobs=choice_logprobs, + finish_reason=finish_reason, + ) + ] + + meta_info = state.get_meta_info(content_var) + + prompt_tokens = meta_info["prompt_tokens"] + completion_tokens = meta_info["completion_tokens"] + response = ChatCompletionResponse( + id=meta_info["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response def to_openai_style_logprobs( diff --git a/server_sglang.py b/server_sglang.py index 5ae7d83..77f6e1d 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -122,6 +122,7 @@ async def health_generate(request: Request) -> Response: @app.get("/get_model_info") async def get_model_info(): + 
global tokenizer_manager result = { "model_path": tokenizer_manager.model_path, "is_generation": tokenizer_manager.is_generation, @@ -309,14 +310,15 @@ def find_free_port(exclude_port: int) -> int: Returns: int: A free port number that is not the excluded port. """ + port = 10000 while True: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) + s.bind(("", port)) if s.getsockname()[1] != exclude_port: return s.getsockname()[1] except socket.error: - continue + port += 1 class FunctionaryRuntime(Runtime): From ab2036840c8f7ee651c84b2689719ca1cd7ad426 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 2 Oct 2024 15:15:35 +0000 Subject: [PATCH 17/40] fix usage tokens --- functionary/sglang_inference.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index f28d3d0..48eed53 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -453,11 +453,13 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): return_text=True, ) - recipient_var = "recipient" content_var = "content" + completion_tokens = 0 @sgl.function def generate_response(s: ProgramState, gen_state: Dict): + nonlocal completion_tokens + s += prompt # Form the options for the following stages @@ -493,11 +495,14 @@ def check_stop_condition(): if gen_state["add_code_interpreter"]: choices.append("python") s += sgl.select( - name=recipient_var, + name=content_var, choices=choices, choices_method=CHOICES_SAMPLING_METHOD, ) - new_token = s[recipient_var] + new_token = s[content_var] + completion_tokens += len( + tokenizer.encode(s[content_var], add_special_tokens=False) + ) elif gen_state["stage"] == "pre-parameter": s += prompt_template.fn_param_sep_token new_token = prompt_template.fn_param_sep_token @@ -509,17 +514,12 @@ def check_stop_condition(): ) s += sgl.gen(name=content_var, regex=regex, stop=function_call_token) new_token = s[content_var] + completion_tokens += s.get_meta_info(content_var)["completion_tokens"] if check_stop_condition(): break - elif gen_state["stage"] == "text-gen": - s += sgl.gen(name=content_var, stop=function_call_token) - if check_stop_condition(): - break - else: - s += function_call_token - new_token = s[content_var] + function_call_token - elif gen_state["stage"] == "code-interpreter": + elif gen_state["stage"] in ["text-gen", "code-interpreter"]: s += sgl.gen(name=content_var, stop=function_call_token) + completion_tokens += s.get_meta_info(content_var)["completion_tokens"] if check_stop_condition(): break else: @@ -581,12 +581,9 @@ def check_stop_condition(): ) ] - meta_info = state.get_meta_info(content_var) - - prompt_tokens = meta_info["prompt_tokens"] - completion_tokens = meta_info["completion_tokens"] + prompt_tokens = len(tokenizer.encode(prompt)) response = ChatCompletionResponse( - id=meta_info["id"], + id=state.get_meta_info(content_var)["id"], model=request.model, choices=choices, usage=UsageInfo( @@ -633,3 +630,4 @@ def append_top_logprobs(top_logprobs): append_top_logprobs(output_top_logprobs) return ret_logprobs + return ret_logprobs From 4b08ba02d453de363cbe774487d02017989474ce Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 2 Oct 2024 15:18:33 +0000 Subject: [PATCH 18/40] quick fix --- functionary/sglang_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 48eed53..c539019 100644 --- 
a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -630,4 +630,3 @@ def append_top_logprobs(top_logprobs): append_top_logprobs(output_top_logprobs) return ret_logprobs - return ret_logprobs From 26700f81e123c71e6c84a9d7b36dc5cc25a53353 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 3 Oct 2024 00:09:31 +0000 Subject: [PATCH 19/40] quick fix --- server_sglang.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/server_sglang.py b/server_sglang.py index 77f6e1d..338eb84 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -14,27 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -""" -The entry point of inference server. -SRT = SGLang Runtime. -""" - import argparse import asyncio import dataclasses From 123f8df7c2c88636b506075a683bf2def492479d Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 10 Oct 2024 13:03:47 +0000 Subject: [PATCH 20/40] wip --- functionary/sglang_inference.py | 74 ++++++++++++++++++++++- requirements_sgl.txt | 4 +- server_sglang.py | 104 +++++++++++++++----------------- 3 files changed, 125 insertions(+), 57 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index c539019..fce3ba8 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -409,6 +409,7 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): request_json = await raw_request.json() request = ChatCompletionRequest(**request_json) tokenizer = backend.get_tokenizer() + request_id = f"cmpl-{uuid.uuid4().hex}" # Convert legacy functions to tools if request.functions is not None: @@ -453,13 +454,15 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): return_text=True, ) - content_var = "content" + content_var = "content_" completion_tokens = 0 @sgl.function def generate_response(s: ProgramState, gen_state: Dict): nonlocal completion_tokens + idx = 0 + s += prompt # Form the options for the following stages @@ -484,6 +487,7 @@ def check_stop_condition(): return stop_match in stop_tokens while True: + content_var = f"content_{idx}" if gen_state["stage"] == "function": choices = [ tool["function"]["name"] @@ -503,6 +507,7 @@ def check_stop_condition(): completion_tokens += len( tokenizer.encode(s[content_var], add_special_tokens=False) ) + idx += 1 elif gen_state["stage"] == "pre-parameter": s += prompt_template.fn_param_sep_token new_token = prompt_template.fn_param_sep_token @@ -515,17 +520,20 @@ def check_stop_condition(): s += sgl.gen(name=content_var, regex=regex, stop=function_call_token) new_token = s[content_var] completion_tokens += s.get_meta_info(content_var)["completion_tokens"] + idx += 1 if check_stop_condition(): break elif gen_state["stage"] in ["text-gen", "code-interpreter"]: s += sgl.gen(name=content_var, 
stop=function_call_token) completion_tokens += s.get_meta_info(content_var)["completion_tokens"] + idx += 1 if check_stop_condition(): break else: s += function_call_token new_token = s[content_var] + function_call_token elif gen_state["stage"] == "pre-function": + breakpoint() s += function_call_token new_token = function_call_token @@ -548,6 +556,69 @@ def check_stop_condition(): stream=request.stream, ) + async def wrap_sgl_generator(): + nonlocal tokenizer, state + + for out in state.text_iter(): + if out.startswith(prompt): + continue + yield out, None + yield "", "stop" + + async def completion_stream_generator(functions): + generator = wrap_sgl_generator() + + tool_call_count = 0 + async for response in generate_openai_format_from_stream_async( + generator, prompt_template, tool_func_choice, tools_or_functions + ): + # Convert tool_calls to function_call if request.functions is provided + if ( + functions + and len(functions) > 0 + and "tool_calls" in response["delta"] + and response["delta"]["tool_calls"] + and len(response["delta"]["tool_calls"]) > 0 + ): + tool_name = response["delta"]["tool_calls"][0]["function"]["name"] + tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] + response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ + "function" + ] + response["delta"]["tool_calls"] = None + if tool_name and len(tool_name) > 0 and tool_args == "": + tool_call_count += 1 + + # Workaround Fixes + response["delta"]["role"] = "assistant" + if ( + "tool_calls" in response["delta"] + and response["delta"]["tool_calls"] + and len(response["delta"]["tool_calls"]) > 0 + ): + for tool_call in response["delta"]["tool_calls"]: + if tool_call.get("type") is None: + tool_call["type"] = "function" + + chunk = StreamChoice(**response) + result = ChatCompletionChunk( + id=request_id, choices=[chunk], model=request.model + ) + chunk_dic = result.dict(exclude_unset=True) + chunk_data = json.dumps(chunk_dic, ensure_ascii=False) + yield f"data: {chunk_data}\n\n" + # Break from for loop after the first tool_call is streamed if functions is provided + if functions and tool_call_count == 2: + break + yield "data: [DONE]\n\n" + + if request.stream: + return StreamingResponse( + completion_stream_generator(functions=request.functions), + media_type="text/event-stream", + # background=tokenizer_manager.create_abort_task(adapted_request), + ) + chat_mess = prompt_template.parse_assistant_response( llm_output=state.text()[len(prompt) :], tool_choice=tool_func_choice ) @@ -630,3 +701,4 @@ def append_top_logprobs(top_logprobs): append_top_logprobs(output_top_logprobs) return ret_logprobs + return ret_logprobs diff --git a/requirements_sgl.txt b/requirements_sgl.txt index 1f131da..6bffbd5 100644 --- a/requirements_sgl.txt +++ b/requirements_sgl.txt @@ -1,4 +1,4 @@ jsonref~=1.1.0 -sglang[all]==0.2.13 +sglang[all]==0.3.3 --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ -flashinfer==0.1.5 +flashinfer==0.1.6 diff --git a/server_sglang.py b/server_sglang.py index 338eb84..9fb09ba 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -16,6 +16,7 @@ # limitations under the License. 
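The "parameter" stage in the grammar-sampling path above constrains tool-call arguments by compiling each tool's JSON Schema into a regular expression with outlines' build_regex_from_schema and handing the result to sgl.gen(regex=...). The sketch below only illustrates that idea; the schema, the candidate string, and the printed check are assumptions for illustration and are not part of the patch.

import json
import re

from outlines.fsm.json_schema import build_regex_from_schema

# Hypothetical tool parameters, used purely for illustration.
schema = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
    },
    "required": ["location"],
}

# Same call the inference code makes before sgl.gen(name=..., regex=regex, ...).
pattern = build_regex_from_schema(json.dumps(schema))

# Any argument text decoded under this constraint must match the pattern,
# so a well-formed candidate should fullmatch it.
candidate = '{"location": "Hanoi", "unit": "celsius"}'
print(re.fullmatch(pattern, candidate) is not None)

The patch additionally appends f"({re.escape(function_call_token)})?" to this pattern so that, right after closing the JSON object, the model may optionally emit the start-of-function-call token and begin another tool call.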
import argparse import asyncio +import atexit import dataclasses import json import logging @@ -28,6 +29,8 @@ from http import HTTPStatus from typing import Dict, List, Optional, Union +import requests + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -48,8 +51,8 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( add_api_key_middleware, - allocate_init_ports, configure_logger, + is_port_available, prepare_model_and_tokenizer, ) @@ -101,7 +104,7 @@ async def health_generate(request: Request) -> Response: @app.get("/get_model_info") async def get_model_info(): - global tokenizer_manager + """Get the model information.""" result = { "model_path": tokenizer_manager.model_path, "is_generation": tokenizer_manager.is_generation, @@ -111,11 +114,13 @@ async def get_model_info(): @app.get("/get_server_args") async def get_server_args(): + """Get the server arguments.""" return dataclasses.asdict(tokenizer_manager.server_args) @app.get("/flush_cache") async def flush_cache(): + """Flush the radix cache.""" tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " @@ -124,6 +129,7 @@ async def flush_cache(): ) +# fastapi implicitly converts json in the request to obj (dataclass) async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: @@ -174,11 +180,11 @@ def available_models(): return ModelList(data=model_cards) -def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """Launch an HTTP server.""" +def launch_engine(server_args: ServerArgs): + """ + Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess. + """ + global tokenizer_manager # Configure global environment @@ -187,18 +193,7 @@ def launch_server( _set_envs_and_config(server_args) # Allocate ports for inter-process communications - server_args.port, server_args.additional_ports = allocate_init_ports( - server_args.port, - server_args.additional_ports, - server_args.dp_size, - ) - ports = server_args.additional_ports - port_args = PortArgs( - tokenizer_port=ports[0], - scheduler_port=ports[1], - detokenizer_port=ports[2], - nccl_ports=ports[3:], - ) + port_args = PortArgs.init_new(server_args) logger.info(f"{server_args=}") # If using model from www.modelscope.cn, first download the model. @@ -255,6 +250,29 @@ def launch_server( for i in range(len(scheduler_pipe_readers)): scheduler_pipe_readers[i].recv() + +def launch_server( + server_args: ServerArgs, + pipe_finish_writer: Optional[mp.connection.Connection] = None, +): + """ + Launch SRT (SGLang Runtime) Server + + The SRT server consists of an HTTP server and the SRT engine. + + 1. HTTP server: A FastAPI server that routes requests to the engine. + 2. SRT engine: + 1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler. + 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server and Tokenizer Manager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
+ """ + + launch_engine(server_args=server_args) + # Add api key authorization if server_args.api_key: add_api_key_middleware(app, server_args.api_key) @@ -279,27 +297,6 @@ def launch_server( t.join() -def find_free_port(exclude_port: int) -> int: - """ - This function finds a free port that is not the excluded port. - - Args: - exclude_port (int): The port number to exclude from selection. - - Returns: - int: A free port number that is not the excluded port. - """ - port = 10000 - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - if s.getsockname()[1] != exclude_port: - return s.getsockname()[1] - except socket.error: - port += 1 - - class FunctionaryRuntime(Runtime): """ A wrapper for the server. @@ -316,17 +313,18 @@ def __init__( """See the arguments in server_args.py::ServerArgs""" self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + # Pre-allocate ports - self.server_args.port, self.server_args.additional_ports = allocate_init_ports( - self.server_args.port, - self.server_args.additional_ports, - self.server_args.dp_size, - ) + for port in range(10000, 40000): + if is_port_available(port): + break + port += 1 + self.server_args.port = port self.url = self.server_args.url() - self.generate_url = ( - f"http://{self.server_args.host}:{self.server_args.port}/generate" - ) + self.generate_url = self.url + "/generate" self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) @@ -373,21 +371,19 @@ def __init__( server_args = ServerArgs.from_cli_args(args) if args.grammar_sampling: - wrapper_port = server_args.port - # Find a new random free port for the backend server runtime - server_args.port = find_free_port(exclude_port=wrapper_port) backend = FunctionaryRuntime(**vars(server_args)) sgl.set_default_backend( - sgl.RuntimeEndpoint(f"http://{server_args.host}:{server_args.port}") + sgl.RuntimeEndpoint( + f"http://{backend.server_args.host}:{backend.server_args.port}" + ) ) uvicorn.run( app, host=server_args.host, - port=wrapper_port, + port=server_args.port, log_level=server_args.log_level_http or server_args.log_level, timeout_keep_alive=5, loop="uvloop", ) - backend.shutdown() else: launch_server(server_args) From 89da8c16423bb7588e214b022d44983fcbac3d94 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 10 Oct 2024 13:34:02 +0000 Subject: [PATCH 21/40] optimize --- functionary/sglang_inference.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index fce3ba8..b02777b 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -454,15 +454,13 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): return_text=True, ) - content_var = "content_" + content_var = "content" completion_tokens = 0 @sgl.function def generate_response(s: ProgramState, gen_state: Dict): nonlocal completion_tokens - idx = 0 - s += prompt # Form the options for the following stages @@ -487,7 +485,6 @@ def check_stop_condition(): return stop_match in stop_tokens while True: - content_var = f"content_{idx}" if gen_state["stage"] == "function": choices = [ tool["function"]["name"] @@ -507,7 +504,6 @@ def check_stop_condition(): completion_tokens += len( tokenizer.encode(s[content_var], add_special_tokens=False) ) - idx += 1 elif 
gen_state["stage"] == "pre-parameter": s += prompt_template.fn_param_sep_token new_token = prompt_template.fn_param_sep_token @@ -520,13 +516,11 @@ def check_stop_condition(): s += sgl.gen(name=content_var, regex=regex, stop=function_call_token) new_token = s[content_var] completion_tokens += s.get_meta_info(content_var)["completion_tokens"] - idx += 1 if check_stop_condition(): break elif gen_state["stage"] in ["text-gen", "code-interpreter"]: s += sgl.gen(name=content_var, stop=function_call_token) completion_tokens += s.get_meta_info(content_var)["completion_tokens"] - idx += 1 if check_stop_condition(): break else: From 02d820a0188644da13287dc7dadf1023f82cb454 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 10 Oct 2024 14:17:45 +0000 Subject: [PATCH 22/40] remove workaround fixes --- .../llama31_prompt_template.py | 14 +++++----- .../prompt_template/llama3_prompt_template.py | 6 ++--- .../llama3_prompt_template_v3.py | 6 ++--- .../prompt_template/prompt_template_v2.py | 6 ++--- functionary/sglang_inference.py | 26 ------------------- 5 files changed, 16 insertions(+), 42 deletions(-) diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py index c237ce1..11cdaa2 100644 --- a/functionary/prompt_template/llama31_prompt_template.py +++ b/functionary/prompt_template/llama31_prompt_template.py @@ -173,7 +173,7 @@ def stream_delta_text( if gen_state["stage"] in ["parameter", "code-interpreter"]: finish_reason = "tool_calls" return gen_state, prompt_utils.get_text_delta_response( - None, False, finish_reason + None, True, finish_reason ) responses = [] @@ -219,12 +219,12 @@ def stream_delta_text( gen_state["first_time_func"] = False responses.append( prompt_utils.get_function_delta_response( - gen_state, "", True, False, finish_reason + gen_state, "", True, True, finish_reason ) ) responses.append( prompt_utils.get_function_delta_response( - gen_state, gen_state["curr_text"], False, False, finish_reason + gen_state, gen_state["curr_text"], True, True, finish_reason ) ) @@ -233,7 +233,7 @@ def stream_delta_text( if len(delta_args) > 0: responses.append( prompt_utils.get_function_delta_response( - gen_state, delta_args, False, False, finish_reason + gen_state, delta_args, True, True, finish_reason ) ) elif " 0 - ): - for tool_call in response["delta"]["tool_calls"]: - if tool_call.get("type") is None: - tool_call["type"] = "function" - chunk = StreamChoice(**response) result = ChatCompletionChunk( id=adapted_request.rid, choices=[chunk], model=request.model @@ -527,7 +516,6 @@ def check_stop_condition(): s += function_call_token new_token = s[content_var] + function_call_token elif gen_state["stage"] == "pre-function": - breakpoint() s += function_call_token new_token = function_call_token @@ -551,8 +539,6 @@ def check_stop_condition(): ) async def wrap_sgl_generator(): - nonlocal tokenizer, state - for out in state.text_iter(): if out.startswith(prompt): continue @@ -583,17 +569,6 @@ async def completion_stream_generator(functions): if tool_name and len(tool_name) > 0 and tool_args == "": tool_call_count += 1 - # Workaround Fixes - response["delta"]["role"] = "assistant" - if ( - "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - for tool_call in response["delta"]["tool_calls"]: - if tool_call.get("type") is None: - tool_call["type"] = "function" - chunk = StreamChoice(**response) result = ChatCompletionChunk( id=request_id, choices=[chunk], 
model=request.model @@ -610,7 +585,6 @@ async def completion_stream_generator(functions): return StreamingResponse( completion_stream_generator(functions=request.functions), media_type="text/event-stream", - # background=tokenizer_manager.create_abort_task(adapted_request), ) chat_mess = prompt_template.parse_assistant_response( From e0eb31d583f8a83516aad3b4f3852c36d5413b1c Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Thu, 10 Oct 2024 14:18:29 +0000 Subject: [PATCH 23/40] fix --- functionary/sglang_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index efc5a05..2edce39 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -669,4 +669,3 @@ def append_top_logprobs(top_logprobs): append_top_logprobs(output_top_logprobs) return ret_logprobs - return ret_logprobs From 4ef5d65dec906b92a709044c59f6965411bac5cb Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Fri, 11 Oct 2024 09:36:17 +0000 Subject: [PATCH 24/40] fixes --- functionary/sglang_inference.py | 859 +++++++++++++++----------------- server_sglang.py | 14 +- 2 files changed, 400 insertions(+), 473 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 2edce39..b6ca887 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -79,19 +79,8 @@ def __init__(self, filename: str, purpose: str): # Choices sampling method for sgl.select CHOICES_SAMPLING_METHOD = greedy_token_selection - - -def format_finish_reason(finish_reason) -> Optional[str]: - if finish_reason.startswith("None"): - return None - elif finish_reason.startswith("FINISH_MATCHED"): - return "stop" - elif finish_reason.startswith("FINISH_LENGTH"): - return "length" - elif finish_reason.startswith("FINISH_ABORT"): - return "abort" - else: - return "unknown" +# Variable name for sgl frontend runtime generation +CONTENT_VAR = "content" def create_error_response( @@ -113,312 +102,91 @@ def create_streaming_error_response( return json_str -def v1_chat_generate_request(all_requests, tokenizer): - input_ids = [] - sampling_params_list = [] - image_data_list = [] - return_logprobs = [] - top_logprobs_nums = [] - for request in all_requests: - # Prep the data needed for the underlying GenerateReqInput: - # - prompt: The full prompt string. - # - stop: Custom stop tokens. - # - image_data: None or a list of image strings (URLs or base64 strings). - # None skips any image processing in GenerateReqInput. - tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice( - request=request - ) - if not isinstance(request.messages, str): - # Apply chat template and its stop strings. - prompt_ids = prepare_messages_for_inference( - tokenizer=tokenizer, - messages=request.messages, - tools_or_functions=tools_or_functions, - tool_choice=tool_func_choice, - device="cpu", - ).tolist()[0] - stop = ( - request.stop - + get_prompt_template_from_tokenizer( - tokenizer=tokenizer - ).get_stop_tokens_for_generation() - ) - image_data = None - else: - # Use the raw prompt and stop strings if the messages is already a string. 
- prompt_ids = request.messages - stop = request.stop - image_data = None - input_ids.append(prompt_ids) - return_logprobs.append(request.logprobs) - top_logprobs_nums.append(request.top_logprobs) - sampling_params_list.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "n": request.n, - } - ) - image_data_list.append(image_data) - if len(all_requests) == 1: - input_ids = input_ids[0] - if isinstance(input_ids, str): - prompt_kwargs = {"text": input_ids} - else: - prompt_kwargs = {"input_ids": input_ids} - sampling_params_list = sampling_params_list[0] - image_data = image_data_list[0] - return_logprobs = return_logprobs[0] - top_logprobs_nums = top_logprobs_nums[0] - else: - if isinstance(input_ids[0], str): - prompt_kwargs = {"text": input_ids} - else: - prompt_kwargs = {"input_ids": input_ids} - adapted_request = GenerateReqInput( - **prompt_kwargs, - image_data=image_data, - sampling_params=sampling_params_list, - return_logprob=return_logprobs, - top_logprobs_num=top_logprobs_nums, - stream=all_requests[0].stream, - return_text_in_logprobs=True, - ) - if len(all_requests) == 1: - return adapted_request, all_requests[0] - return adapted_request, all_requests - - -def v1_chat_generate_response(request, prompt_template, ret): - choices = [] - - _, tool_func_choice = analyze_tools_and_tool_choice(request=request) - - for idx, ret_item in enumerate(ret): - logprobs = False - if isinstance(request, list) and request[idx].logprobs: - logprobs = True - elif (not isinstance(request, list)) and request.logprobs: - logprobs = True - if logprobs: - logprobs = to_openai_style_logprobs( - output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], - ) - token_logprobs = [] - for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs): - token_bytes = list(token.encode("utf-8")) - top_logprobs = [] - if logprobs.top_logprobs: - for top_token, top_logprob in logprobs.top_logprobs[0].items(): - top_token_bytes = list(top_token.encode("utf-8")) - top_logprobs.append( - TopLogprob( - token=top_token, - bytes=top_token_bytes, - logprob=top_logprob, - ) - ) - token_logprobs.append( - ChatCompletionTokenLogprob( - token=token, - bytes=token_bytes, - logprob=logprob, - top_logprobs=top_logprobs, - ) - ) - - choice_logprobs = ChoiceLogprobs(content=token_logprobs) - else: - choice_logprobs = None - - chat_mess = prompt_template.parse_assistant_response( - llm_output=ret_item["text"], tool_choice=tool_func_choice - ) - finish_reason = False - - # Convert tool_calls to function_call if request.functions is provided - if ( - request.functions - and "tool_calls" in chat_mess - and chat_mess["tool_calls"] is not None - and len(chat_mess["tool_calls"]) > 0 - ): - chat_mess["function_call"] = { - "name": chat_mess["tool_calls"][0]["function"]["name"], - "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], - } - chat_mess["tool_calls"] = None - - # Postprocess finish reason - if "function_call" in chat_mess and chat_mess["function_call"]: - finish_reason = "function_call" - - if "tool_calls" in chat_mess and chat_mess["tool_calls"]: - finish_reason = "tool_calls" - - if not 
finish_reason: - finish_reason = format_finish_reason(ret_item["meta_info"]["finish_reason"]) +def convert_tool_calls_to_function_call( + functions: Optional[List[Function]], chat_message: Dict +): + if ( + functions + and len(functions) > 0 + and "tool_calls" in chat_message + and chat_message["tool_calls"] is not None + and len(chat_message["tool_calls"]) > 0 + ): + chat_message["function_call"] = { + "name": chat_message["tool_calls"][0]["function"]["name"], + "arguments": chat_message["tool_calls"][0]["function"]["arguments"], + } + chat_message["tool_calls"] = None - choice_data = ChatCompletionResponseChoice( - index=idx, - message=ChatMessage(**chat_mess), - # logprobs=choice_logprobs, - finish_reason=finish_reason, - ) + return chat_message - choices.append(choice_data) - prompt_tokens = sum( - ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) +def v1_chat_generate_request( + request, tokenizer, tools_or_functions, tool_func_choice, return_text=False +): + # Apply chat template and its stop strings. + input_ids = prepare_messages_for_inference( + tokenizer=tokenizer, + messages=request.messages, + tools_or_functions=tools_or_functions, + tool_choice=tool_func_choice, + device="cpu", + return_text=return_text, ) - completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) - response = ChatCompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), + if not return_text: + input_ids = input_ids.tolist()[0] + + stop = ( + request.stop + + get_prompt_template_from_tokenizer( + tokenizer=tokenizer + ).get_stop_tokens_for_generation() ) - return response - - -async def v1_chat_completions(tokenizer_manager, raw_request: Request): - request_json = await raw_request.json() - all_requests = [ChatCompletionRequest(**request_json)] - tokenizer = tokenizer_manager.tokenizer + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + } + + if isinstance(input_ids, str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} - prompt_template = get_prompt_template_from_tokenizer( - tokenizer=tokenizer_manager.tokenizer - ) - tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice( - all_requests[0] + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=None, + sampling_params=sampling_params, + return_logprob=request.logprobs, + top_logprobs_num=request.top_logprobs, + stream=request.stream, + return_text_in_logprobs=True, + rid=f"cmpl-{uuid.uuid4().hex}", ) - adapted_request, request = v1_chat_generate_request(all_requests, tokenizer) + return adapted_request, request - if adapted_request.stream: - - async def wrap_sgl_generator(): - stream_buffer = "" - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request - ): - prompt_tokens = content["meta_info"]["prompt_tokens"] - completion_tokens = content["meta_info"]["completion_tokens"] - text = content["text"] - delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + 
delta - finish_reason = format_finish_reason( - content["meta_info"]["finish_reason"] - ) - - # If finish_reason is not None and delta_text is not empty, - # the delta_text is the eos_token and just remove it - if finish_reason is not None and len(delta) > 0: - delta = "" - yield delta, finish_reason - - async def completion_stream_generator(): - generator = wrap_sgl_generator() - - tool_call_count = 0 - async for response in generate_openai_format_from_stream_async( - generator, prompt_template, tool_func_choice, tools_or_functions - ): - # Convert tool_calls to function_call if request.functions is provided - if ( - request.functions - and len(request.functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"][ - "arguments" - ] - response["delta"]["function_call"] = response["delta"][ - "tool_calls" - ][0]["function"] - response["delta"]["tool_calls"] = None - if tool_name and len(tool_name) > 0 and tool_args == "": - tool_call_count += 1 - # Return finish_reason after the first tool_call is streamed if functions is provided - if request.functions and tool_call_count == 2: - response["delta"] = {} - response["finish_reason"] = "function_call" - - chunk = StreamChoice(**response) - result = ChatCompletionChunk( - id=adapted_request.rid, choices=[chunk], model=request.model - ) - chunk_dic = result.dict(exclude_unset=True) - chunk_data = json.dumps(chunk_dic, ensure_ascii=False) - yield f"data: {chunk_data}\n\n" - # Break from for loop after the first tool_call is streamed if functions is provided - if request.functions and tool_call_count == 2: - break - yield "data: [DONE]\n\n" - - return StreamingResponse( - # generate_stream_resp(), - completion_stream_generator(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), - ) - - # Non-streaming response. 
- try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - if not isinstance(ret, list): - ret = [ret] - - response = v1_chat_generate_response(request, prompt_template, ret) - - return response - - -async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): - request_json = await raw_request.json() - request = ChatCompletionRequest(**request_json) - tokenizer = backend.get_tokenizer() - request_id = f"cmpl-{uuid.uuid4().hex}" - - # Convert legacy functions to tools - if request.functions is not None: - request.tools = [ - Tool(type="function", function=function) for function in request.functions - ] - # Convert legacy function_call to tool_choice - if request.function_call is not None: - if isinstance(request.function_call, str) and ( - request.function_call == "none" or request.function_call == "auto" - ): - request.tool_choice = request.function_call - if request.function_call and isinstance(request.function_call, Function): - request.tool_choice = Tool( - type="function", function=Function(name=request.function_call.name) - ) - - prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) - tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) +@sgl.function +def generate_sglang_srt_response( + s: ProgramState, + prompt: str, + prompt_template, + tools_or_functions, + tool_func_choice, + tokenizer, +): + completion_tokens = 0 + stop_tokens = prompt_template.get_stop_tokens_for_generation() + function_call_token = prompt_template.get_start_of_function_call_token() gen_state = prompt_template.initialize_fsm_gen_state( tool_choice=tool_func_choice, curr_text="", @@ -435,174 +203,287 @@ async def v1_chat_completions_grammar_sampling(backend, raw_request: Request): else False ), ) - prompt = prepare_messages_for_inference( - tokenizer=tokenizer, - messages=request.messages, - tools_or_functions=tools_or_functions, - tool_choice=tool_func_choice, - return_text=True, + # Form the options for the following stages + tools = [] + for tool in tools_or_functions: + if "type" in tool: + if tool["type"] == "function": + tools.append(tool["function"]) + else: + tools.append(tool) + options = prompt_template.get_options_from_gen_state( + gen_state=gen_state, tools_or_functions=tools ) - content_var = "content" - completion_tokens = 0 - - @sgl.function - def generate_response(s: ProgramState, gen_state: Dict): - nonlocal completion_tokens - - s += prompt - - # Form the options for the following stages - tools = [] - for tool in tools_or_functions: - if "type" in tool: - if tool["type"] == "function": - tools.append(tool["function"]) + def check_stop_condition(): + stop_match = s.get_meta_info(CONTENT_VAR)["finish_reason"]["matched"] + if not isinstance(stop_match, str): + stop_match = tokenizer.decode(stop_match) + return stop_match in stop_tokens + + s += prompt + while True: + if gen_state["stage"] == "function": + choices = [ + tool["function"]["name"] + for tool in tools_or_functions + if tool["type"] == "function" + ] + if gen_state["add_all_recipient"]: + choices.append("all") + if gen_state["add_code_interpreter"]: + choices.append("python") + s += sgl.select( + name=CONTENT_VAR, + choices=choices, + choices_method=CHOICES_SAMPLING_METHOD, + ) + new_token = s[CONTENT_VAR] + completion_tokens += len( + tokenizer.encode(s[CONTENT_VAR], add_special_tokens=False) + ) + elif gen_state["stage"] == "pre-parameter": + s += 
prompt_template.fn_param_sep_token + new_token = prompt_template.fn_param_sep_token + elif gen_state["stage"] == "parameter": + tool = next(t for t in tools if t["name"] == gen_state["func_name"]) + regex = ( + build_regex_from_schema(json.dumps(tool["parameters"])) + + f"({re.escape(function_call_token)})?" + ) + s += sgl.gen(name=CONTENT_VAR, regex=regex, stop=function_call_token) + new_token = s[CONTENT_VAR] + completion_tokens += s.get_meta_info(CONTENT_VAR)["completion_tokens"] + if check_stop_condition(): + break + elif gen_state["stage"] in ["text-gen", "code-interpreter"]: + s += sgl.gen(name=CONTENT_VAR, stop=function_call_token) + completion_tokens += s.get_meta_info(CONTENT_VAR)["completion_tokens"] + if check_stop_condition(): + break else: - tools.append(tool) - options = prompt_template.get_options_from_gen_state( - gen_state=gen_state, tools_or_functions=tools - ) - - stop_tokens = prompt_template.get_stop_tokens_for_generation() - function_call_token = prompt_template.get_start_of_function_call_token() - - def check_stop_condition(): - stop_match = s.get_meta_info(content_var)["finish_reason"]["matched"] - if not isinstance(stop_match, str): - stop_match = tokenizer.decode(stop_match) - return stop_match in stop_tokens - - while True: - if gen_state["stage"] == "function": - choices = [ - tool["function"]["name"] - for tool in tools_or_functions - if tool["type"] == "function" - ] - if gen_state["add_all_recipient"]: - choices.append("all") - if gen_state["add_code_interpreter"]: - choices.append("python") - s += sgl.select( - name=content_var, - choices=choices, - choices_method=CHOICES_SAMPLING_METHOD, - ) - new_token = s[content_var] - completion_tokens += len( - tokenizer.encode(s[content_var], add_special_tokens=False) - ) - elif gen_state["stage"] == "pre-parameter": - s += prompt_template.fn_param_sep_token - new_token = prompt_template.fn_param_sep_token - elif gen_state["stage"] == "parameter": - tool = next(t for t in tools if t["name"] == gen_state["func_name"]) - regex = ( - build_regex_from_schema(json.dumps(tool["parameters"])) - + f"({re.escape(function_call_token)})?" 
- ) - s += sgl.gen(name=content_var, regex=regex, stop=function_call_token) - new_token = s[content_var] - completion_tokens += s.get_meta_info(content_var)["completion_tokens"] - if check_stop_condition(): - break - elif gen_state["stage"] in ["text-gen", "code-interpreter"]: - s += sgl.gen(name=content_var, stop=function_call_token) - completion_tokens += s.get_meta_info(content_var)["completion_tokens"] - if check_stop_condition(): - break - else: - s += function_call_token - new_token = s[content_var] + function_call_token - elif gen_state["stage"] == "pre-function": s += function_call_token - new_token = function_call_token - - gen_state = prompt_template.update_fsm_gen_state( - gen_state=gen_state, - new_token=new_token, - new_token_id=None, - options=options, - tokenizer=tokenizer, - ) + new_token = s[CONTENT_VAR] + function_call_token + elif gen_state["stage"] == "pre-function": + s += function_call_token + new_token = function_call_token + gen_state = prompt_template.update_fsm_gen_state( + gen_state=gen_state, + new_token=new_token, + new_token_id=None, + options=options, + tokenizer=tokenizer, + ) - state = generate_response.run( - gen_state=gen_state, - max_new_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - frequency_penalty=request.frequency_penalty, - presence_penalty=request.presence_penalty, - stream=request.stream, - ) - async def wrap_sgl_generator(): - for out in state.text_iter(): +async def wrap_sgl_generator( + adapted_request, + raw_request, + request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, + frontend_state, + grammar_sampling, +): + if grammar_sampling: + prompt = ( + adapted_request.text + if adapted_request.text + else tokenizer.decode(adapted_request.input_ids) + ) + for out in frontend_state.text_iter(): if out.startswith(prompt): continue yield out, None yield "", "stop" - - async def completion_stream_generator(functions): - generator = wrap_sgl_generator() - - tool_call_count = 0 - async for response in generate_openai_format_from_stream_async( - generator, prompt_template, tool_func_choice, tools_or_functions + else: + stream_buffer = "" + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request ): - # Convert tool_calls to function_call if request.functions is provided - if ( - functions - and len(functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] - response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ - "function" - ] - response["delta"]["tool_calls"] = None - if tool_name and len(tool_name) > 0 and tool_args == "": - tool_call_count += 1 - - chunk = StreamChoice(**response) - result = ChatCompletionChunk( - id=request_id, choices=[chunk], model=request.model - ) - chunk_dic = result.dict(exclude_unset=True) - chunk_data = json.dumps(chunk_dic, ensure_ascii=False) - yield f"data: {chunk_data}\n\n" - # Break from for loop after the first tool_call is streamed if functions is provided - if functions and tool_call_count == 2: - break - yield "data: [DONE]\n\n" + text = content["text"] + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + finish_reason = content["meta_info"]["finish_reason"] + + # If finish_reason is not None 
and delta_text is not empty, + # the delta_text is the eos_token and just remove it + if finish_reason is not None and len(delta) > 0: + delta = "" + yield delta, finish_reason + + +async def completion_stream_generator( + adapted_request, + raw_request, + request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, + frontend_state, + grammar_sampling, +): + generator = wrap_sgl_generator( + adapted_request, + raw_request, + request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, + frontend_state, + grammar_sampling, + ) - if request.stream: - return StreamingResponse( - completion_stream_generator(functions=request.functions), - media_type="text/event-stream", + tool_call_count = 0 + async for response in generate_openai_format_from_stream_async( + generator, prompt_template, tool_func_choice, tools_or_functions + ): + # Convert tool_calls to function_call if request.functions is provided + if ( + request.functions + and len(request.functions) > 0 + and "tool_calls" in response["delta"] + and response["delta"]["tool_calls"] + and len(response["delta"]["tool_calls"]) > 0 + ): + tool_name = response["delta"]["tool_calls"][0]["function"]["name"] + tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] + response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ + "function" + ] + response["delta"]["tool_calls"] = None + if tool_name and len(tool_name) > 0 and tool_args == "": + tool_call_count += 1 + + # Return finish_reason after the first tool_call is streamed if functions is provided + if request.functions and tool_call_count == 2: + response["delta"] = {} + response["finish_reason"] = "function_call" + + chunk = StreamChoice(**response) + result = ChatCompletionChunk( + id=adapted_request.rid, choices=[chunk], model=request.model ) - + chunk_dic = result.dict(exclude_unset=True) + chunk_data = json.dumps(chunk_dic, ensure_ascii=False) + yield f"data: {chunk_data}\n\n" + # Break from for loop after the first tool_call is streamed if functions is provided + if request.functions and tool_call_count == 2: + break + yield "data: [DONE]\n\n" + + +async def v1_chat_generate_completion( + adapted_request, + raw_request, + request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, +): + grammar_sampling = True if backend else False + if grammar_sampling: + prompt = ( + adapted_request.text + if adapted_request.text + else tokenizer.decode(adapted_request.input_ids) + ) + state = generate_sglang_srt_response.run( + prompt=prompt, + prompt_template=prompt_template, + tools_or_functions=tools_or_functions, + tool_func_choice=tool_func_choice, + tokenizer=tokenizer, + max_new_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + stream=request.stream, + ) + if adapted_request.stream: + return ( + StreamingResponse( + completion_stream_generator( + adapted_request, + raw_request, + request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, + state, + grammar_sampling, + ), + media_type="text/event-stream", + ), + None, + ) + else: + return state.text()[len(prompt) :], None + else: + if adapted_request.stream: + return ( + StreamingResponse( + completion_stream_generator( + adapted_request, + raw_request, 
+ request, + tokenizer, + tokenizer_manager, + backend, + prompt_template, + tools_or_functions, + tool_func_choice, + None, + grammar_sampling, + ), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ), + None, + ) + else: + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return None, create_error_response(str(e)) + return ret["text"], None + + +def v1_chat_generate_response( + adapted_request, + raw_request, + request, + output_text, + prompt_template, + tokenizer, + tool_func_choice, +): chat_mess = prompt_template.parse_assistant_response( - llm_output=state.text()[len(prompt) :], tool_choice=tool_func_choice + llm_output=output_text, tool_choice=tool_func_choice + ) + chat_mess = convert_tool_calls_to_function_call( + functions=request.functions, chat_message=chat_mess ) - - # Convert tool_calls to function_call if request.functions is provided - if ( - request.functions - and "tool_calls" in chat_mess - and chat_mess["tool_calls"] is not None - and len(chat_mess["tool_calls"]) > 0 - ): - chat_mess["function_call"] = { - "name": chat_mess["tool_calls"][0]["function"]["name"], - "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], - } - chat_mess["tool_calls"] = None # Postprocess finish reason finish_reason = "stop" @@ -615,14 +496,18 @@ async def completion_stream_generator(functions): ChatCompletionResponseChoice( index=0, message=ChatMessage(**chat_mess), - # logprobs=choice_logprobs, finish_reason=finish_reason, ) ] + prompt_tokens = ( + len(adapted_request.input_ids) + if adapted_request.input_ids + else len(tokenizer.encode(adapted_request.text)) + ) + completion_tokens = len(tokenizer.encode(output_text, add_special_tokens=False)) + 1 - prompt_tokens = len(tokenizer.encode(prompt)) response = ChatCompletionResponse( - id=state.get_meta_info(content_var)["id"], + id=adapted_request.rid, model=request.model, choices=choices, usage=UsageInfo( @@ -634,6 +519,50 @@ async def completion_stream_generator(functions): return response +async def v1_chat_completions(tokenizer_manager, backend, raw_request: Request): + request_json = await raw_request.json() + request = ChatCompletionRequest(**request_json) + tokenizer = ( + tokenizer_manager.tokenizer if tokenizer_manager else backend.get_tokenizer() + ) + + prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) + tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) + + adapted_request, request = v1_chat_generate_request( + request, tokenizer, tools_or_functions, tool_func_choice, return_text=False + ) + + output, error = await v1_chat_generate_completion( + adapted_request=adapted_request, + raw_request=raw_request, + request=request, + tokenizer=tokenizer, + tokenizer_manager=tokenizer_manager, + backend=backend, + prompt_template=prompt_template, + tools_or_functions=tools_or_functions, + tool_func_choice=tool_func_choice, + ) + if error: + return error + + if adapted_request.stream: + return output + + response = v1_chat_generate_response( + adapted_request=adapted_request, + raw_request=raw_request, + request=request, + output_text=output, + prompt_template=prompt_template, + tokenizer=tokenizer, + tool_func_choice=tool_func_choice, + ) + + return response + + def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, diff --git a/server_sglang.py b/server_sglang.py index 9fb09ba..29bac1e 100644 --- a/server_sglang.py 
+++ b/server_sglang.py @@ -56,10 +56,7 @@ prepare_model_and_tokenizer, ) -from functionary.sglang_inference import ( - v1_chat_completions, - v1_chat_completions_grammar_sampling, -) +from functionary.sglang_inference import v1_chat_completions from functionary.sglang_monkey_patch.tokenizer_manager import ( MonkeyPatchTokenizerManager, ) @@ -164,10 +161,11 @@ async def stream_results(): @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - if args.grammar_sampling: - return await v1_chat_completions_grammar_sampling(backend, raw_request) - else: - return await v1_chat_completions(tokenizer_manager, raw_request) + global tokenizer_manager, backend + + if not args.grammar_sampling: + backend = None + return await v1_chat_completions(tokenizer_manager, backend, raw_request) @app.get("/v1/models") From e8faa0f2127c4a67f8727243adf25b0ab8d78a52 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Fri, 11 Oct 2024 10:16:07 +0000 Subject: [PATCH 25/40] fixes --- functionary/sglang_inference.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index b6ca887..6b50076 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -247,13 +247,12 @@ def check_stop_condition(): new_token = prompt_template.fn_param_sep_token elif gen_state["stage"] == "parameter": tool = next(t for t in tools if t["name"] == gen_state["func_name"]) - regex = ( - build_regex_from_schema(json.dumps(tool["parameters"])) - + f"({re.escape(function_call_token)})?" - ) + regex = build_regex_from_schema(json.dumps(tool["parameters"])) s += sgl.gen(name=CONTENT_VAR, regex=regex, stop=function_call_token) new_token = s[CONTENT_VAR] completion_tokens += s.get_meta_info(CONTENT_VAR)["completion_tokens"] + # Generate new token to determin if there is another tool call + s += sgl.gen(name=CONTENT_VAR, stop=function_call_token) if check_stop_condition(): break elif gen_state["stage"] in ["text-gen", "code-interpreter"]: From e3c9de50ba461920684e1d0cc9ceb2baa35ee8fe Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Fri, 11 Oct 2024 15:35:30 +0000 Subject: [PATCH 26/40] fixes --- functionary/sglang_inference.py | 459 ++++++++++++++++++-------------- 1 file changed, 255 insertions(+), 204 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 6b50076..0ea6614 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -21,8 +21,9 @@ import re import time import uuid +from dataclasses import dataclass from http import HTTPStatus -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union import sglang as sgl from fastapi import Request @@ -31,16 +32,10 @@ from sglang.lang.choices import greedy_token_selection from sglang.lang.interpreter import ProgramState from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.openai_api.protocol import ( - BatchResponse, - ChatCompletionTokenLogprob, - ChoiceLogprobs, - DeltaMessage, - ErrorResponse, - FileResponse, - LogProbs, - TopLogprob, -) +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.openai_api.protocol import ErrorResponse +from sglang.srt.server import Runtime +from transformers import AutoTokenizer from functionary.inference_stream import generate_openai_format_from_stream_async from functionary.inference_utils import analyze_tools_and_tool_choice @@ -55,7 +50,10 @@ Tool, 
UsageInfo, ) -from functionary.prompt_template import get_prompt_template_from_tokenizer +from functionary.prompt_template import ( + PromptTemplate, + get_prompt_template_from_tokenizer, +) from functionary.prompt_template.prompt_utils import prepare_messages_for_inference @@ -65,24 +63,29 @@ def __init__(self, filename: str, purpose: str): self.purpose = purpose -# In-memory storage for batch jobs and files -batch_storage: Dict[str, BatchResponse] = {} -file_id_request: Dict[str, FileMetadata] = {} -file_id_response: Dict[str, FileResponse] = {} -# map file id to file path in SGLang backend -file_id_storage: Dict[str, str] = {} - - -# backend storage directory -storage_dir = None - - # Choices sampling method for sgl.select CHOICES_SAMPLING_METHOD = greedy_token_selection # Variable name for sgl frontend runtime generation CONTENT_VAR = "content" +@dataclass +class ChatCompletionParams: + """Parameters and context used across various chat completion functions""" + + adapted_request: GenerateReqInput + raw_request: Request + request: ChatCompletionRequest + tokenizer: AutoTokenizer + tokenizer_manager: Optional[TokenizerManager] + srt_backend: Optional[Runtime] + prompt_template: PromptTemplate + tools_or_functions: List[Dict] + tool_func_choice: Optional[Union[str, Tool, Function]] + frontend_state: Optional[ProgramState] + grammar_sampling: bool + + def create_error_response( message: str, err_type: str = "BadRequestError", @@ -104,7 +107,7 @@ def create_streaming_error_response( def convert_tool_calls_to_function_call( functions: Optional[List[Function]], chat_message: Dict -): +) -> Dict: if ( functions and len(functions) > 0 @@ -122,8 +125,34 @@ def convert_tool_calls_to_function_call( def v1_chat_generate_request( - request, tokenizer, tools_or_functions, tool_func_choice, return_text=False -): + request: ChatCompletionRequest, + tokenizer: AutoTokenizer, + tools_or_functions: List[Dict], + tool_func_choice: Optional[Union[str, Tool, Function]], + return_text: bool = False, +) -> Tuple[GenerateReqInput, ChatCompletionRequest]: + """ + Generate an adapted request that SGLang uses. + + This function prepares the input for SGLang inference by processing the chat completion request, + applying the appropriate tokenization, and setting up the sampling parameters. + + Args: + request (ChatCompletionRequest): The original chat completion request. + tokenizer (AutoTokenizer): The tokenizer to use for encoding the text input, if any. + tools_or_functions (List[Dict]): List of available tools or functions. + tool_func_choice (Optional[Union[str, Tool, Function]]): The chosen tool or function, if any. + return_text (bool, optional): Whether to return the input as text instead of token IDs. Defaults to False. + + Returns: + Tuple[GenerateReqInput, ChatCompletionRequest]: A tuple containing: + - The adapted request (GenerateReqInput) to be used by SGLang. + - The original request (ChatCompletionRequest), NOT modified. + + Note: + This function handles the conversion of the chat messages into a format suitable for SGLang, + applies the chat template, sets up stopping criteria, and configures sampling parameters. + """ # Apply chat template and its stop strings. input_ids = prepare_messages_for_inference( tokenizer=tokenizer, @@ -184,6 +213,24 @@ def generate_sglang_srt_response( tool_func_choice, tokenizer, ): + """ + Generate a response using SGLang Frontend Runtime (SRT). + + This function is used when grammar-sampling is enabled. 
It uses the SRT program + state to update the specific prompt-template Finite State Machine (FSM) generation + state. Constrained generation is performed at specific stages of the FSM. + + Args: + s (ProgramState): The current program state in SGLang. + prompt (str): The input prompt to generate a response for. + prompt_template: The template used to structure the prompt and response. + tools_or_functions (list): Available tools or functions for the model to use. + tool_func_choice (str): The chosen tool or function choice. + tokenizer: The tokenizer used for encoding and decoding text. + + Returns: + ProgramState: The updated program state after generating the response. + """ completion_tokens = 0 stop_tokens = prompt_template.get_stop_tokens_for_generation() function_call_token = prompt_template.get_start_of_function_call_token() @@ -275,34 +322,38 @@ def check_stop_condition(): ) -async def wrap_sgl_generator( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, - frontend_state, - grammar_sampling, -): - if grammar_sampling: +async def wrap_sgl_generator(params: ChatCompletionParams): + """ + This asynchronous generator function yields generated text chunks along + with their finish reasons. + + Args: + params (ChatCompletionParams): A dataclass containing all necessary + parameters for the chat completion, including the request details, + tokenizer, backend, and other configuration options. + + Yields: + Tuple[str, Optional[str]]: A tuple containing: + - str: The generated text chunk. + - Optional[str]: The finish reason, if any (e.g., "stop", "length", etc.). + """ + if params.grammar_sampling: prompt = ( - adapted_request.text - if adapted_request.text - else tokenizer.decode(adapted_request.input_ids) + params.adapted_request.text + if params.adapted_request.text + else params.tokenizer.decode(params.adapted_request.input_ids) ) - for out in frontend_state.text_iter(): + # Iterates over the text generated by the SGLang Frontend Runtime + for out in params.frontend_state.text_iter(): if out.startswith(prompt): continue yield out, None yield "", "stop" else: + # Iterates over the text generated by the tokenizer manager stream_buffer = "" - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request + async for content in params.tokenizer_manager.generate_request( + params.adapted_request, params.raw_request ): text = content["text"] delta = text[len(stream_buffer) :] @@ -316,41 +367,43 @@ async def wrap_sgl_generator( yield delta, finish_reason -async def completion_stream_generator( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, - frontend_state, - grammar_sampling, -): - generator = wrap_sgl_generator( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, - frontend_state, - grammar_sampling, - ) +async def completion_stream_generator(params: ChatCompletionParams): + """ + This asynchronous generator function produces a stream of ChatCompletionChunk + objects. It handles both grammar-sampling and regular generations, + depending on the parameters provided. + + Args: + params (ChatCompletionParams): A dataclass containing all necessary + parameters for the chat completion, including the request details, + tokenizer, backend, and other configuration options. 
+ + Yields: + str: JSON-formatted strings representing chunks of the chat completion + response, including delta updates and finish reasons. + + Notes: + - The function adapts its behavior based on whether grammar sampling + is enabled or not. + - It handles the conversion of tool calls to function calls when + appropriate. + - The stream is terminated with a "[DONE]" message. + """ + # Initialize the text generator + generator = wrap_sgl_generator(params) tool_call_count = 0 + # Generate the text in openai format async for response in generate_openai_format_from_stream_async( - generator, prompt_template, tool_func_choice, tools_or_functions + generator, + params.prompt_template, + params.tool_func_choice, + params.tools_or_functions, ): # Convert tool_calls to function_call if request.functions is provided if ( - request.functions - and len(request.functions) > 0 + params.request.functions + and len(params.request.functions) > 0 and "tool_calls" in response["delta"] and response["delta"]["tool_calls"] and len(response["delta"]["tool_calls"]) > 0 @@ -365,71 +418,75 @@ async def completion_stream_generator( tool_call_count += 1 # Return finish_reason after the first tool_call is streamed if functions is provided - if request.functions and tool_call_count == 2: + if params.request.functions and tool_call_count == 2: response["delta"] = {} response["finish_reason"] = "function_call" chunk = StreamChoice(**response) result = ChatCompletionChunk( - id=adapted_request.rid, choices=[chunk], model=request.model + id=params.adapted_request.rid, choices=[chunk], model=params.request.model ) chunk_dic = result.dict(exclude_unset=True) chunk_data = json.dumps(chunk_dic, ensure_ascii=False) yield f"data: {chunk_data}\n\n" # Break from for loop after the first tool_call is streamed if functions is provided - if request.functions and tool_call_count == 2: + if params.request.functions and tool_call_count == 2: break yield "data: [DONE]\n\n" async def v1_chat_generate_completion( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, -): - grammar_sampling = True if backend else False - if grammar_sampling: + params: ChatCompletionParams, +) -> Tuple[Union[StreamingResponse, str], Optional[JSONResponse]]: + """ + Generate a text completion. + + This function handles both streaming and non-streaming responses for chat completions. + It supports both regular and grammar-sampling generations. + + Args: + params (ChatCompletionParams): A dataclass containing all necessary parameters and context + for generating the text. + + Returns: + Tuple[Union[StreamingResponse, str], Optional[JSONResponse]]: + - If streaming is requested, returns a StreamingResponse object. + - If non-streaming, returns the generated text as a string. + - The second element is an optional JSONResponse for error cases. + + Note: + - For grammar-sampling, it uses the SGLang Frontend Runtime. + - For regular generation, it uses the tokenizer manager to generate the response. + - Streaming responses are handled by the completion_stream_generator function. 
+ """ + # If streaming, return the StreamingResponse else return the text + if params.grammar_sampling: + # Form the text prompt and run the SGLang Frontend Runtime prompt = ( - adapted_request.text - if adapted_request.text - else tokenizer.decode(adapted_request.input_ids) + params.adapted_request.text + if params.adapted_request.text + else params.tokenizer.decode(params.adapted_request.input_ids) ) state = generate_sglang_srt_response.run( prompt=prompt, - prompt_template=prompt_template, - tools_or_functions=tools_or_functions, - tool_func_choice=tool_func_choice, - tokenizer=tokenizer, - max_new_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - frequency_penalty=request.frequency_penalty, - presence_penalty=request.presence_penalty, - stream=request.stream, + prompt_template=params.prompt_template, + tools_or_functions=params.tools_or_functions, + tool_func_choice=params.tool_func_choice, + tokenizer=params.tokenizer, + max_new_tokens=params.request.max_tokens, + temperature=params.request.temperature, + top_p=params.request.top_p, + top_k=params.request.top_k, + frequency_penalty=params.request.frequency_penalty, + presence_penalty=params.request.presence_penalty, + stream=params.request.stream, ) - if adapted_request.stream: + + if params.adapted_request.stream: + params.frontend_state = state return ( StreamingResponse( - completion_stream_generator( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, - state, - grammar_sampling, - ), + completion_stream_generator(params), media_type="text/event-stream", ), None, @@ -437,31 +494,21 @@ async def v1_chat_generate_completion( else: return state.text()[len(prompt) :], None else: - if adapted_request.stream: + if params.adapted_request.stream: return ( StreamingResponse( - completion_stream_generator( - adapted_request, - raw_request, - request, - tokenizer, - tokenizer_manager, - backend, - prompt_template, - tools_or_functions, - tool_func_choice, - None, - grammar_sampling, - ), + completion_stream_generator(params), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), + background=params.tokenizer_manager.create_abort_task( + params.adapted_request + ), ), None, ) else: try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request + ret = await params.tokenizer_manager.generate_request( + params.adapted_request, params.raw_request ).__anext__() except ValueError as e: return None, create_error_response(str(e)) @@ -469,19 +516,29 @@ async def v1_chat_generate_completion( def v1_chat_generate_response( - adapted_request, - raw_request, - request, - output_text, - prompt_template, - tokenizer, - tool_func_choice, -): - chat_mess = prompt_template.parse_assistant_response( - llm_output=output_text, tool_choice=tool_func_choice + output_text: str, params: ChatCompletionParams +) -> ChatCompletionResponse: + """ + Generate a ChatCompletionResponse from the output text and parameters. + + This function processes the output text, parses it according to the prompt template, + and constructs a ChatCompletionResponse object. + + Args: + output_text (str): The raw output text from SGLang inference. + params (ChatCompletionParams): Parameters and context for the chat completion. 
+ + Returns: + ChatCompletionResponse: An OpenAI-compatible response containing the assistant's message, + usage information, and other metadata. + """ + # Parse the output text using the specific prompt template + chat_mess = params.prompt_template.parse_assistant_response( + llm_output=output_text, tool_choice=params.tool_func_choice ) + # Convert tool_calls to function_call if request.functions is provided chat_mess = convert_tool_calls_to_function_call( - functions=request.functions, chat_message=chat_mess + functions=params.request.functions, chat_message=chat_mess ) # Postprocess finish reason @@ -499,15 +556,17 @@ def v1_chat_generate_response( ) ] prompt_tokens = ( - len(adapted_request.input_ids) - if adapted_request.input_ids - else len(tokenizer.encode(adapted_request.text)) + len(params.adapted_request.input_ids) + if params.adapted_request.input_ids + else len(params.tokenizer.encode(params.adapted_request.text)) ) - completion_tokens = len(tokenizer.encode(output_text, add_special_tokens=False)) + 1 + completion_tokens = ( + len(params.tokenizer.encode(output_text, add_special_tokens=False)) + 1 + ) # +1 for the eos token response = ChatCompletionResponse( - id=adapted_request.rid, - model=request.model, + id=params.adapted_request.rid, + model=params.request.model, choices=choices, usage=UsageInfo( prompt_tokens=prompt_tokens, @@ -518,82 +577,74 @@ def v1_chat_generate_response( return response -async def v1_chat_completions(tokenizer_manager, backend, raw_request: Request): +async def v1_chat_completions( + tokenizer_manager: Optional[TokenizerManager], + srt_backend: Optional[Runtime], + raw_request: Request, +): + """ + Handle chat completions for v1 of the API. + + This function processes the incoming request, prepares the necessary parameters, + generates the chat completion, and returns the response. It supports both + streaming and non-streaming responses. + + Args: + tokenizer_manager (Optional[TokenizerManager]): Manager for tokenization tasks. + None if grammar sampling is enabled. + srt_backend (Optional[Runtime]): The SRT backend for processing. + None if grammar sampling is disabled. + raw_request (Request): The raw incoming request object. + + Returns: + Union[ChatCompletionResponse, StreamingResponse, JSONResponse]: + - ChatCompletionResponse for non-streaming successful responses. + - StreamingResponse for streaming responses. + - JSONResponse for error responses. + + Raises: + No explicit raises, but may return error responses for various failure scenarios. 
+ """ request_json = await raw_request.json() request = ChatCompletionRequest(**request_json) tokenizer = ( - tokenizer_manager.tokenizer if tokenizer_manager else backend.get_tokenizer() + tokenizer_manager.tokenizer + if tokenizer_manager + else srt_backend.get_tokenizer() ) - prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) + # Generate the adapted request adapted_request, request = v1_chat_generate_request( request, tokenizer, tools_or_functions, tool_func_choice, return_text=False ) - output, error = await v1_chat_generate_completion( + # Prepare the parameters for generate_completion and generate_response functions + params = ChatCompletionParams( adapted_request=adapted_request, raw_request=raw_request, request=request, tokenizer=tokenizer, tokenizer_manager=tokenizer_manager, - backend=backend, + srt_backend=srt_backend, prompt_template=prompt_template, tools_or_functions=tools_or_functions, tool_func_choice=tool_func_choice, + frontend_state=None, # None first. Set later if needed + grammar_sampling=True if srt_backend else False, ) + + # Generate the text completion + output, error = await v1_chat_generate_completion(params) if error: return error + # If streaming, return the output(StreamingResponse) directly if adapted_request.stream: return output - response = v1_chat_generate_response( - adapted_request=adapted_request, - raw_request=raw_request, - request=request, - output_text=output, - prompt_template=prompt_template, - tokenizer=tokenizer, - tool_func_choice=tool_func_choice, - ) + # Generate the API response + response = v1_chat_generate_response(output_text=output, params=params) return response - - -def to_openai_style_logprobs( - input_token_logprobs=None, - output_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, -): - ret_logprobs = LogProbs() - - def append_token_logprobs(token_logprobs): - for logprob, _, token_text in token_logprobs: - ret_logprobs.tokens.append(token_text) - ret_logprobs.token_logprobs.append(logprob) - - # Not supported yet - ret_logprobs.text_offset.append(-1) - - def append_top_logprobs(top_logprobs): - for tokens in top_logprobs: - if tokens is not None: - ret_logprobs.top_logprobs.append( - {token[2]: token[0] for token in tokens} - ) - else: - ret_logprobs.top_logprobs.append(None) - - if input_token_logprobs is not None: - append_token_logprobs(input_token_logprobs) - if output_token_logprobs is not None: - append_token_logprobs(output_token_logprobs) - if input_top_logprobs is not None: - append_top_logprobs(input_top_logprobs) - if output_top_logprobs is not None: - append_top_logprobs(output_top_logprobs) - - return ret_logprobs From 9531f387d732e7b1e504bc1d64063f3be841a491 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Sat, 12 Oct 2024 07:34:09 +0000 Subject: [PATCH 27/40] fix --- functionary/sglang_inference.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 0ea6614..59be009 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -56,13 +56,6 @@ ) from functionary.prompt_template.prompt_utils import prepare_messages_for_inference - -class FileMetadata: - def __init__(self, filename: str, purpose: str): - self.filename = filename - self.purpose = purpose - - # Choices sampling method for sgl.select CHOICES_SAMPLING_METHOD = greedy_token_selection # Variable name for sgl frontend runtime generation 
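The series above (patches 24 through 27) consolidates both backends behind a single OpenAI-compatible /v1/chat/completions route in server_sglang.py, with grammar sampling toggled at server start rather than per request. A minimal client sketch follows; it assumes a server launched via server_sglang.py on 127.0.0.1:8000 serving meetkai/functionary-small-v3.1 (the host, port, model, and tool schema mirror the test suite added in the next patch and are illustrative, not mandated by these patches).

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="test")

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"],
            },
        },
    }

    # Non-streaming: the server parses the raw model output into an OpenAI-style
    # message, so tool calls arrive as structured objects.
    response = client.chat.completions.create(
        model="meetkai/functionary-small-v3.1",
        messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
        tools=[weather_tool],
        tool_choice="auto",
        temperature=0.0,
    )
    print(response.choices[0].finish_reason)
    for tool_call in response.choices[0].message.tool_calls or []:
        print(tool_call.function.name, tool_call.function.arguments)

    # Streaming: the server emits SSE chunks terminated by "data: [DONE]"; the
    # OpenAI client surfaces them as ChatCompletionChunk objects.
    stream = client.chat.completions.create(
        model="meetkai/functionary-small-v3.1",
        messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
        tools=[weather_tool],
        tool_choice="auto",
        temperature=0.0,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")
        if delta.tool_calls:
            fragment = delta.tool_calls[0].function
            print(fragment.name or "", fragment.arguments or "", end="")

Whether constrained (grammar-sampling) generation is used is decided when the server is started via the --enable-grammar-sampling flag; the request payload is the same in both modes.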
From 9d1444ffbb0f0a22512471ad9542978e6589030e Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Mon, 14 Oct 2024 14:16:01 +0000 Subject: [PATCH 28/40] set up unittest --- functionary/openai_types.py | 2 + functionary/prompt_template/base_template.py | 1 - .../llama31_prompt_template.py | 27 +- .../prompt_template/llama3_prompt_template.py | 8 +- .../llama3_prompt_template_v3.py | 6 +- .../prompt_template/prompt_template_v2.py | 6 +- functionary/sglang_inference.py | 22 +- tests/test_sgl_server.py | 468 ++++++++++++++++++ 8 files changed, 510 insertions(+), 30 deletions(-) create mode 100644 tests/test_sgl_server.py diff --git a/functionary/openai_types.py b/functionary/openai_types.py index bbe1396..74325ce 100644 --- a/functionary/openai_types.py +++ b/functionary/openai_types.py @@ -97,11 +97,13 @@ class StreamChoice(BaseModel): finish_reason: Optional[str] = "stop" index: int = 0 + class UsageInfo(BaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 + class ChatCompletionChunk(BaseModel): id: str object: str = "chat.completion.chunk" diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py index 028560a..61618d6 100644 --- a/functionary/prompt_template/base_template.py +++ b/functionary/prompt_template/base_template.py @@ -241,7 +241,6 @@ def _update_gen_state_for_fn_call(self, gen_state: Dict, func_name: str): gen_state["func_name"] = func_name gen_state["func_index"] += 1 gen_state["call_id"] = prompt_utils.get_random_tool_call_id() - gen_state["first_time_func"] = True return gen_state diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py index 11cdaa2..e235266 100644 --- a/functionary/prompt_template/llama31_prompt_template.py +++ b/functionary/prompt_template/llama31_prompt_template.py @@ -173,7 +173,7 @@ def stream_delta_text( if gen_state["stage"] in ["parameter", "code-interpreter"]: finish_reason = "tool_calls" return gen_state, prompt_utils.get_text_delta_response( - None, True, finish_reason + None, False, finish_reason ) responses = [] @@ -189,7 +189,7 @@ def stream_delta_text( gen_state["gen_empty_text"] = False responses.append( prompt_utils.get_text_delta_response( - gen_state["curr_text"], True, finish_reason + gen_state["curr_text"], False, finish_reason ) ) text_in_buffer = "".join(gen_state["text_to_func_buffer"] + [delta_text]) @@ -201,7 +201,7 @@ def stream_delta_text( delta_text_to_stream = gen_state["text_to_func_buffer"][0] responses.append( prompt_utils.get_text_delta_response( - delta_text_to_stream, True, finish_reason + delta_text_to_stream, False, finish_reason ) ) gen_state["text_to_func_buffer"] = gen_state["text_to_func_buffer"][ @@ -209,7 +209,7 @@ def stream_delta_text( ] responses.append( prompt_utils.get_text_delta_response( - delta_text, True, finish_reason + delta_text, False, finish_reason ) ) else: @@ -222,18 +222,23 @@ def stream_delta_text( gen_state, "", True, True, finish_reason ) ) - responses.append( - prompt_utils.get_function_delta_response( - gen_state, gen_state["curr_text"], True, True, finish_reason + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_function_delta_response( + gen_state, + gen_state["curr_text"], + False, + False, + finish_reason, + ) ) - ) if " 0: responses.append( prompt_utils.get_function_delta_response( - gen_state, delta_args, True, True, finish_reason + gen_state, delta_args, False, False, finish_reason ) ) elif " 0: - delta = "" + if 
finish_reason is not None: + finish_reason = finish_reason["type"] + if len(delta) > 0: + delta = "" yield delta, finish_reason @@ -419,7 +421,7 @@ async def completion_stream_generator(params: ChatCompletionParams): result = ChatCompletionChunk( id=params.adapted_request.rid, choices=[chunk], model=params.request.model ) - chunk_dic = result.dict(exclude_unset=True) + chunk_dic = result.model_dump() chunk_data = json.dumps(chunk_dic, ensure_ascii=False) yield f"data: {chunk_data}\n\n" # Break from for loop after the first tool_call is streamed if functions is provided @@ -536,10 +538,14 @@ def v1_chat_generate_response( # Postprocess finish reason finish_reason = "stop" - if "function_call" in chat_mess and chat_mess["function_call"]: - finish_reason = "function_call" - if "tool_calls" in chat_mess and chat_mess["tool_calls"]: - finish_reason = "tool_calls" + if params.tool_func_choice is None or params.tool_func_choice in [ + "auto", + "required", + ]: + if "function_call" in chat_mess and chat_mess["function_call"]: + finish_reason = "function_call" + if "tool_calls" in chat_mess and chat_mess["tool_calls"]: + finish_reason = "tool_calls" choices = [ ChatCompletionResponseChoice( diff --git a/tests/test_sgl_server.py b/tests/test_sgl_server.py new file mode 100644 index 0000000..25200e3 --- /dev/null +++ b/tests/test_sgl_server.py @@ -0,0 +1,468 @@ +import json +import subprocess +import time +import unittest +from typing import Dict, List, Optional + +import psutil +import requests +from openai import OpenAI +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) +from rich import print + +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 + + +def popen_launch_sgl_server( + model: str, + base_url: str, + timeout: float, + context_length: int, + grammar_sampling: bool, + env: Optional[dict] = None, + return_stdout_stderr: bool = False, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "server_sglang.py", + "--model", + model, + "--host", + host, + "--port", + str(port), + "--context-length", + str(context_length), + ] + if grammar_sampling: + command += ["--enable-grammar-sampling"] + + if return_stdout_stderr: + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + text=True, + ) + else: + process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + + start_time = time.time() + api_key = "test" + while time.time() - start_time < timeout: + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {api_key}", + } + response = requests.get(f"{base_url}/health", headers=headers) + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(5) + raise TimeoutError("Server failed to start within the timeout period.") + + +def kill_child_process(pid, including_parent=True, skip_pid=None): + """Kill the process and all its children process.""" + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + children = parent.children(recursive=True) + for child in children: + if child.pid == skip_pid: + continue + try: + child.kill() + except psutil.NoSuchProcess: + pass + + if including_parent: + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + +def call_openai_api( + test_case: Dict, + client: OpenAI, + model: str, + default_tools: List, + python_tool: Dict, + 
default_functions: List, + stream: bool = False, +): + if test_case["call_mode"] == "tools": + if test_case["code_interpreter"]: + if model.startswith("meetkai"): + tools = default_tools + [{"type": "code_interpreter"}] + else: + tools = default_tools + [python_tool] + else: + tools = default_tools + response = client.chat.completions.create( + model=model, + messages=test_case["messages"], + tools=tools, + tool_choice=test_case["choice"], + temperature=0.0, + stream=stream, + ) + else: + response = client.chat.completions.create( + model=model, + messages=test_case["messages"], + functions=default_functions, + function_call=test_case["choice"], + temperature=0.0, + stream=False, + ) + return response + + +class TestSglServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.openai_model = "gpt-4o-mini-2024-07-18" + cls.base_url = "http://127.0.0.1:8000" + cls.default_functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + ] + cls.default_tools = [{"type": "function", "function": cls.default_functions[0]}] + cls.python_tool = { + "type": "function", + "function": { + "name": "python", + "description": "Generate Python code", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to execute", + } + }, + "required": ["code"], + }, + }, + } + cls.request_handling_test_cases = [ + { + "test_aim": 'Normal text gen with "auto"', + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Single tool_calls with "auto"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + # { + # "test_aim": 'Single function_call with "auto"', + # "messages": [ + # {"role": "user", "content": "What is the weather in Istanbul?"} + # ], + # "call_mode": "functions", + # "code_interpreter": False, + # "choice": "auto", + # }, + { + "test_aim": 'Parallel tool_calls with "auto"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + # { + # "test_aim": 'Parallel function_calls with "auto"', + # "messages": [ + # { + # "role": "user", + # "content": "What is the weather in Istanbul and Singapore respectively?", + # } + # ], + # "call_mode": "functions", + # "code_interpreter": False, + # "choice": "auto", + # }, + # { + # "test_aim": 'Normal text gen + tool_calls with "auto"', + # "messages": [ + # { + # "role": "user", + # "content": "How are you? 
Can you also check what is the weather in Istanbul?", + # } + # ], + # "call_mode": "tools", + # "code_interpreter": False, + # "choice": "auto", + # }, + { + "test_aim": 'Normal text gen with "none"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "none", + }, + { + "test_aim": "tool_calls with tool_choice", + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "tools", + "code_interpreter": False, + "choice": { + "type": "function", + "function": {"name": cls.default_functions[0]["name"]}, + }, + }, + # { + # "test_aim": "function_call with function_call", + # "messages": [{"role": "user", "content": "How are you?"}], + # "call_mode": "functions", + # "code_interpreter": False, + # "choice": {"name": cls.default_functions[0]["name"]}, + # }, + { + "test_aim": 'parallel tool_calls with "required"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "required", + }, + # { + # "test_aim": 'code generation using "python" tool', + # "messages": [ + # { + # "role": "user", + # "content": "Use the Python tool to write a Python function that adds 2 integers.", + # } + # ], + # "call_mode": "tools", + # "code_interpreter": True, + # "choice": "auto", + # }, + # { + # "test_aim": 'Normal text generation (CoT) + code generation using "python" tool', + # "messages": [ + # { + # "role": "user", + # "content": "Write a Python function that adds 2 integers. Think step by step before generating code using the python tool.", + # } + # ], + # "call_mode": "tools", + # "code_interpreter": True, + # "choice": "auto", + # }, + ] + cls.client = OpenAI() + for i, test_case in enumerate(cls.request_handling_test_cases): + response = call_openai_api( + test_case=test_case, + client=cls.client, + model=cls.openai_model, + default_tools=cls.default_tools, + python_tool=cls.python_tool, + default_functions=cls.default_functions, + ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls and len(tool_calls) > 0: + for j in range(len(tool_calls)): + if tool_calls[j].function.name == "python": + response.choices[0].message.tool_calls[j].function.arguments = ( + json.loads(tool_calls[j].function.arguments)["code"] + ) + cls.request_handling_test_cases[i]["label"] = response + + response = call_openai_api( + test_case=test_case, + client=cls.client, + model=cls.openai_model, + default_tools=cls.default_tools, + python_tool=cls.python_tool, + default_functions=cls.default_functions, + stream=True, + ) + chunks = [chunk for chunk in response] + cls.request_handling_test_cases[i]["stream_label"] = chunks + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_child_process(cls.process.pid) + + def __init__(self, *args, **kwargs): + super(TestSglServer, self).__init__(*args, **kwargs) + self.served_models = [ + # "meetkai/functionary-small-v2.4", + # "meetkai/functionary-small-v2.5", + "meetkai/functionary-small-v3.1", + # "meetkai/functionary-small-v3.2", + ] + + def _check_nonstreaming_response(self, pred, label): + # Check if both label.id and pred.id start with the same prefix + assert pred.id.startswith(label.id[: label.id.index("-")]) + # Check if objects are equal + assert pred.object == label.object + pred_content = pred.choices[0].message.content + label_content = label.choices[0].message.content + 
pred_tool_calls = pred.choices[0].message.tool_calls + label_tool_calls = label.choices[0].message.tool_calls + pred_fn_call = pred.choices[0].message.function_call + label_fn_call = label.choices[0].message.function_call + # Check if content is equal + assert (pred_content is None) == (label_content is None) + # Check if tool_calls are equal + assert (pred_tool_calls is None) == (label_tool_calls is None) + if label_tool_calls is not None: + assert len(pred_tool_calls) == len(label_tool_calls) + for pred_tool_call, label_tool_call in zip( + pred_tool_calls, label_tool_calls + ): + assert isinstance(pred_tool_call, ChatCompletionMessageToolCall) + assert pred_tool_call.id.startswith( + "call_" + ) and label_tool_call.id.startswith("call_") + assert pred_tool_call.type == label_tool_call.type + assert pred_tool_call.function.name == label_tool_call.function.name + assert pred_tool_call.function.arguments is not None + # Check if function_calls are equal + assert (pred_fn_call is None) == (label_fn_call is None) + if label_fn_call is not None: + assert isinstance(pred_fn_call, FunctionCall) + assert pred_fn_call.name == label_fn_call.name + assert pred_fn_call.arguments is not None + # Check finish_reason + assert pred.choices[0].finish_reason == label.choices[0].finish_reason + + def _check_streaming_response(self, pred, label): + if sum([chunk.choices[0].delta.role == "assistant" for chunk in label]) > 1: + breakpoint() + tool_call_id = -1 + for i, chunk in enumerate(pred): + # Check if both label.id and pred.id start with the same prefix + assert chunk.id.startswith(label[0].id[: label[0].id.index("-")]) + # Check if objects are equal + assert chunk.object == label[0].object + # Check if the assistant turn is in the first chunk only + if i == 0: + assert chunk.choices[0].delta.role == "assistant" + else: + assert chunk.choices[0].delta.role is None + # Check if the finish_reason is in the last chunk only + if i == len(pred) - 1: + assert chunk.choices[0].finish_reason is not None + else: + assert chunk.choices[0].finish_reason is None + # Check if only one of content, function_call or tool_calls is not None + non_none_fields = [ + chunk.choices[0].delta.content is not None, + chunk.choices[0].delta.function_call is not None, + chunk.choices[0].delta.tool_calls is not None, + ] + if i == len(pred) - 1: + assert sum(non_none_fields) == 0 + else: + assert sum(non_none_fields) == 1 + # Check tool_calls + if chunk.choices[0].delta.tool_calls is not None: + call_type = chunk.choices[0].delta.tool_calls[0].type + name = chunk.choices[0].delta.tool_calls[0].function.name + args = chunk.choices[0].delta.tool_calls[0].function.arguments + # Check name, arguments, call_type and index + assert args is not None + if len(args) == 0: + assert name is not None + assert call_type == "function" + tool_call_id += 1 + assert chunk.choices[0].delta.tool_calls[0].index == tool_call_id + else: + assert name is None + assert call_type is None + # Function call seems bugged in OpenAI so not checking this + # Check function_call + # if chunk.choices[0].delta.function_call is not None: + # name = chunk.choices[0].delta.function_call.name + # args = chunk.choices[0].delta.function_call.arguments + + def test_sgl_server(self): + for model in self.served_models: + self.process = popen_launch_sgl_server( + model=model, + base_url=self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + context_length=4096, + grammar_sampling=False, + ) + + self.client = OpenAI(base_url=f"{self.base_url}/v1", api_key="test") + try: 
+ for test_case in self.request_handling_test_cases: + pred = call_openai_api( + test_case=test_case, + client=self.client, + model=model, + default_tools=self.default_tools, + python_tool=self.python_tool, + default_functions=self.default_functions, + ) + label = test_case["label"] + self._check_nonstreaming_response(pred, label) + pred = call_openai_api( + test_case=test_case, + client=self.client, + model=model, + default_tools=self.default_tools, + python_tool=self.python_tool, + default_functions=self.default_functions, + stream=True, + ) + pred = [chunk for chunk in pred] + label = test_case["stream_label"] + self._check_streaming_response(pred, label) + except AssertionError: + raise + finally: + if self.process: + kill_child_process(self.process.pid) From a19f68febc12e4adbccc07774f41fe27cc102bab Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 15 Oct 2024 10:02:28 +0000 Subject: [PATCH 29/40] fixes --- .../llama31_prompt_template.py | 15 +- functionary/sglang_inference.py | 1 + tests/test_sgl_server.py | 144 +++++++++--------- 3 files changed, 81 insertions(+), 79 deletions(-) diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py index e235266..9571749 100644 --- a/functionary/prompt_template/llama31_prompt_template.py +++ b/functionary/prompt_template/llama31_prompt_template.py @@ -148,8 +148,7 @@ def initialize_fsm_gen_state( "func_name": func_name, "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool - "gen_empty_text": True, # if first_time we return an empty delta with role=assistant - "first_time_func": True, + "first_chunk": True, "text_to_func_buffer": [], "clear_buffer": False, "add_code_interpreter": add_code_interpreter, @@ -182,11 +181,11 @@ def stream_delta_text( ) if gen_state["stage"] == "text-gen": - if gen_state["gen_empty_text"]: + if gen_state["first_chunk"]: responses.append( prompt_utils.get_text_delta_response("", True, finish_reason) ) - gen_state["gen_empty_text"] = False + gen_state["first_chunk"] = False responses.append( prompt_utils.get_text_delta_response( gen_state["curr_text"], False, finish_reason @@ -215,8 +214,8 @@ def stream_delta_text( else: gen_state["text_to_func_buffer"].append(delta_text) elif gen_state["stage"] == "parameter": - if gen_state["first_time_func"]: - gen_state["first_time_func"] = False + if gen_state["first_chunk"]: + gen_state["first_chunk"] = False responses.append( prompt_utils.get_function_delta_response( gen_state, "", True, True, finish_reason @@ -256,8 +255,8 @@ def stream_delta_text( ) ) elif gen_state["stage"] == "code-interpreter": - if gen_state["first_time_func"]: - gen_state["first_time_func"] = False + if gen_state["first_chunk"]: + gen_state["first_chunk"] = False first_function_response = prompt_utils.get_function_delta_response( gen_state, "", True, True, finish_reason ) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 0e927a8..a3cd03d 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -176,6 +176,7 @@ def v1_chat_generate_request( "repetition_penalty": request.repetition_penalty, "regex": request.regex, "n": request.n, + "skip_special_tokens": False, } if isinstance(input_ids, str): diff --git a/tests/test_sgl_server.py b/tests/test_sgl_server.py index 25200e3..829ec7c 100644 --- a/tests/test_sgl_server.py +++ b/tests/test_sgl_server.py @@ -126,7 +126,7 @@ def call_openai_api( functions=default_functions, 
function_call=test_case["choice"], temperature=0.0, - stream=False, + stream=stream, ) return response @@ -187,15 +187,15 @@ def setUpClass(cls): "code_interpreter": False, "choice": "auto", }, - # { - # "test_aim": 'Single function_call with "auto"', - # "messages": [ - # {"role": "user", "content": "What is the weather in Istanbul?"} - # ], - # "call_mode": "functions", - # "code_interpreter": False, - # "choice": "auto", - # }, + { + "test_aim": 'Single function_call with "auto"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "functions", + "code_interpreter": False, + "choice": "auto", + }, { "test_aim": 'Parallel tool_calls with "auto"', "messages": [ @@ -208,30 +208,30 @@ def setUpClass(cls): "code_interpreter": False, "choice": "auto", }, - # { - # "test_aim": 'Parallel function_calls with "auto"', - # "messages": [ - # { - # "role": "user", - # "content": "What is the weather in Istanbul and Singapore respectively?", - # } - # ], - # "call_mode": "functions", - # "code_interpreter": False, - # "choice": "auto", - # }, - # { - # "test_aim": 'Normal text gen + tool_calls with "auto"', - # "messages": [ - # { - # "role": "user", - # "content": "How are you? Can you also check what is the weather in Istanbul?", - # } - # ], - # "call_mode": "tools", - # "code_interpreter": False, - # "choice": "auto", - # }, + { + "test_aim": 'Parallel function_calls with "auto"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "functions", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Normal text gen + tool_calls with "auto"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul? Answer this question: 'How are you?', before checking the weather.", + } + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, { "test_aim": 'Normal text gen with "none"', "messages": [ @@ -251,13 +251,13 @@ def setUpClass(cls): "function": {"name": cls.default_functions[0]["name"]}, }, }, - # { - # "test_aim": "function_call with function_call", - # "messages": [{"role": "user", "content": "How are you?"}], - # "call_mode": "functions", - # "code_interpreter": False, - # "choice": {"name": cls.default_functions[0]["name"]}, - # }, + { + "test_aim": "function_call with function_call", + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "functions", + "code_interpreter": False, + "choice": {"name": cls.default_functions[0]["name"]}, + }, { "test_aim": 'parallel tool_calls with "required"', "messages": [ @@ -270,30 +270,30 @@ def setUpClass(cls): "code_interpreter": False, "choice": "required", }, - # { - # "test_aim": 'code generation using "python" tool', - # "messages": [ - # { - # "role": "user", - # "content": "Use the Python tool to write a Python function that adds 2 integers.", - # } - # ], - # "call_mode": "tools", - # "code_interpreter": True, - # "choice": "auto", - # }, - # { - # "test_aim": 'Normal text generation (CoT) + code generation using "python" tool', - # "messages": [ - # { - # "role": "user", - # "content": "Write a Python function that adds 2 integers. 
Think step by step before generating code using the python tool.", - # } - # ], - # "call_mode": "tools", - # "code_interpreter": True, - # "choice": "auto", - # }, + { + "test_aim": 'code generation using "python" tool', + "messages": [ + { + "role": "user", + "content": "Use the Python tool to write a Python function that adds 2 integers.", + } + ], + "call_mode": "tools", + "code_interpreter": True, + "choice": "auto", + }, + { + "test_aim": 'Normal text generation (CoT) + code generation using "python" tool', + "messages": [ + { + "role": "user", + "content": "Write a Python function that adds 2 integers. Answer this question: 'How are you?', before using the python tool.", + } + ], + "call_mode": "tools", + "code_interpreter": True, + "choice": "auto", + }, ] cls.client = OpenAI() for i, test_case in enumerate(cls.request_handling_test_cases): @@ -377,8 +377,6 @@ def _check_nonstreaming_response(self, pred, label): assert pred.choices[0].finish_reason == label.choices[0].finish_reason def _check_streaming_response(self, pred, label): - if sum([chunk.choices[0].delta.role == "assistant" for chunk in label]) > 1: - breakpoint() tool_call_id = -1 for i, chunk in enumerate(pred): # Check if both label.id and pred.id start with the same prefix @@ -420,11 +418,15 @@ def _check_streaming_response(self, pred, label): else: assert name is None assert call_type is None - # Function call seems bugged in OpenAI so not checking this # Check function_call - # if chunk.choices[0].delta.function_call is not None: - # name = chunk.choices[0].delta.function_call.name - # args = chunk.choices[0].delta.function_call.arguments + if chunk.choices[0].delta.function_call is not None: + name = chunk.choices[0].delta.function_call.name + args = chunk.choices[0].delta.function_call.arguments + assert args is not None + if len(args) == 0: + assert name is not None + else: + assert name is None def test_sgl_server(self): for model in self.served_models: From 930cfc25011fc91a6211b7592e94ee0e7b9c280f Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 15 Oct 2024 10:51:00 +0000 Subject: [PATCH 30/40] pass unittest for v3 template regular generation --- .../prompt_template/llama3_prompt_template_v3.py | 13 ++++++------- tests/test_sgl_server.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/functionary/prompt_template/llama3_prompt_template_v3.py b/functionary/prompt_template/llama3_prompt_template_v3.py index 6cea27b..7d0dfd6 100644 --- a/functionary/prompt_template/llama3_prompt_template_v3.py +++ b/functionary/prompt_template/llama3_prompt_template_v3.py @@ -212,8 +212,7 @@ def initialize_fsm_gen_state( "func_name": func_name, "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool - "gen_empty_text": True, # if first_time we return an empty delta with role=assistant - "first_time_func": True, + "first_chunk": True, "add_all_recipient": add_all_recipient, "add_code_interpreter": add_code_interpreter, } @@ -247,11 +246,11 @@ def stream_delta_text( if gen_state["stage"] == "text-gen": if delta_text != self.function_separator: - if gen_state["gen_empty_text"]: + if gen_state["first_chunk"]: responses.append( prompt_utils.get_text_delta_response("", True, finish_reason) ) - gen_state["gen_empty_text"] = False + gen_state["first_chunk"] = False responses.append( prompt_utils.get_text_delta_response( delta_text, False, finish_reason @@ -259,16 +258,16 @@ def stream_delta_text( ) elif gen_state["stage"] in ["parameter", "code-interpreter"]: if 
delta_text != self.function_separator: - if gen_state["first_time_func"]: + if gen_state["first_chunk"]: responses.append( prompt_utils.get_function_delta_response( gen_state, "", True, True, finish_reason ) ) - gen_state["first_time_func"] = False + gen_state["first_chunk"] = False responses.append( prompt_utils.get_function_delta_response( - gen_state, delta_text, True, False, finish_reason + gen_state, delta_text, False, False, finish_reason ) ) diff --git a/tests/test_sgl_server.py b/tests/test_sgl_server.py index 829ec7c..628e036 100644 --- a/tests/test_sgl_server.py +++ b/tests/test_sgl_server.py @@ -337,7 +337,7 @@ def __init__(self, *args, **kwargs): # "meetkai/functionary-small-v2.4", # "meetkai/functionary-small-v2.5", "meetkai/functionary-small-v3.1", - # "meetkai/functionary-small-v3.2", + "meetkai/functionary-small-v3.2", ] def _check_nonstreaming_response(self, pred, label): From e45656361a147d58f8908b7b1813fe649710f04a Mon Sep 17 00:00:00 2001 From: root Date: Wed, 16 Oct 2024 04:19:34 +0000 Subject: [PATCH 31/40] fix --- functionary/prompt_template/base_template.py | 1 + .../prompt_template/llama31_prompt_template.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py index 61618d6..f7898e6 100644 --- a/functionary/prompt_template/base_template.py +++ b/functionary/prompt_template/base_template.py @@ -241,6 +241,7 @@ def _update_gen_state_for_fn_call(self, gen_state: Dict, func_name: str): gen_state["func_name"] = func_name gen_state["func_index"] += 1 gen_state["call_id"] = prompt_utils.get_random_tool_call_id() + gen_state["first_function_chunk"] = True return gen_state diff --git a/functionary/prompt_template/llama31_prompt_template.py b/functionary/prompt_template/llama31_prompt_template.py index 9571749..81171c0 100644 --- a/functionary/prompt_template/llama31_prompt_template.py +++ b/functionary/prompt_template/llama31_prompt_template.py @@ -149,6 +149,7 @@ def initialize_fsm_gen_state( "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool "first_chunk": True, + "first_function_chunk": True, "text_to_func_buffer": [], "clear_buffer": False, "add_code_interpreter": add_code_interpreter, @@ -214,13 +215,14 @@ def stream_delta_text( else: gen_state["text_to_func_buffer"].append(delta_text) elif gen_state["stage"] == "parameter": - if gen_state["first_chunk"]: - gen_state["first_chunk"] = False + if gen_state["first_function_chunk"]: responses.append( prompt_utils.get_function_delta_response( - gen_state, "", True, True, finish_reason + gen_state, "", True, gen_state["first_chunk"], finish_reason ) ) + gen_state["first_chunk"] = False + gen_state["first_function_chunk"] = False if gen_state["curr_text"] != "": responses.append( prompt_utils.get_function_delta_response( @@ -255,11 +257,12 @@ def stream_delta_text( ) ) elif gen_state["stage"] == "code-interpreter": - if gen_state["first_chunk"]: - gen_state["first_chunk"] = False + if gen_state["first_function_chunk"]: first_function_response = prompt_utils.get_function_delta_response( - gen_state, "", True, True, finish_reason + gen_state, "", True, gen_state["first_chunk"], finish_reason ) + gen_state["first_chunk"] = False + gen_state["first_function_chunk"] = False responses.append(first_function_response) responses.append( prompt_utils.get_function_delta_response( From 11d2de1f04a57c01f39365bfa86c4fa21e0048f1 Mon Sep 17 00:00:00 2001 From: root Date: 
Wed, 16 Oct 2024 09:12:43 +0000 Subject: [PATCH 32/40] fix --- .../llama3_prompt_template_v3.py | 40 ++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/functionary/prompt_template/llama3_prompt_template_v3.py b/functionary/prompt_template/llama3_prompt_template_v3.py index 7d0dfd6..46569ca 100644 --- a/functionary/prompt_template/llama3_prompt_template_v3.py +++ b/functionary/prompt_template/llama3_prompt_template_v3.py @@ -213,6 +213,7 @@ def initialize_fsm_gen_state( "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool "first_chunk": True, + "first_function_chunk": True, "add_all_recipient": add_all_recipient, "add_code_interpreter": add_code_interpreter, } @@ -251,6 +252,12 @@ def stream_delta_text( prompt_utils.get_text_delta_response("", True, finish_reason) ) gen_state["first_chunk"] = False + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_text_delta_response( + gen_state["curr_text"], False, finish_reason + ) + ) responses.append( prompt_utils.get_text_delta_response( delta_text, False, finish_reason @@ -258,13 +265,25 @@ def stream_delta_text( ) elif gen_state["stage"] in ["parameter", "code-interpreter"]: if delta_text != self.function_separator: - if gen_state["first_chunk"]: + if gen_state["first_function_chunk"]: responses.append( prompt_utils.get_function_delta_response( - gen_state, "", True, True, finish_reason + gen_state, "", True, gen_state["first_chunk"], finish_reason ) ) gen_state["first_chunk"] = False + gen_state["first_function_chunk"] = False + + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_function_delta_response( + gen_state, + gen_state["curr_text"], + False, + False, + finish_reason, + ) + ) responses.append( prompt_utils.get_function_delta_response( gen_state, delta_text, False, False, finish_reason @@ -299,9 +318,13 @@ def update_fsm_gen_state( # v2: "{func_name}\n{param_names}\n<|from|> assistant\n<|recipient|>" if gen_state["stage"] == "pre-function": # Check if the new state is in "function" stage - if gen_state["curr_text"].endswith(self.get_start_of_function_call_token()): + if self.get_start_of_function_call_token() in gen_state["curr_text"]: gen_state["stage"] = "function" + curr_text = gen_state["curr_text"] gen_state = self._reset_fsm_curr_text_and_tokens(gen_state=gen_state) + gen_state["curr_text"] = curr_text.removeprefix( + self.get_start_of_function_call_token() + ) gen_state["func_name"] = "" elif gen_state["stage"] == "function": @@ -324,16 +347,21 @@ def update_fsm_gen_state( # Use the suffix from curr_text as the prefix in "pre-parameter" tool_name = options[options_mask.index(True)] suffix = curr_text[len(tool_name) :] - gen_state = self._update_gen_state_for_fn_call( - gen_state=gen_state, func_name=tool_name - ) + if tool_name == "all": + gen_state["func_name"] = tool_name + else: + gen_state = self._update_gen_state_for_fn_call( + gen_state=gen_state, func_name=tool_name + ) gen_state = self._reset_fsm_curr_text_and_tokens(gen_state=gen_state) # Jump to "parameter" stage if suffix is "\n" gen_state["stage"] = "pre-parameter" if suffix == "" else "parameter" elif gen_state["stage"] == "pre-parameter": if self.fn_param_sep_token in gen_state["curr_text"]: + curr_text = gen_state["curr_text"] gen_state = self._reset_fsm_curr_text_and_tokens(gen_state=gen_state) + gen_state["curr_text"] = curr_text.removeprefix(self.fn_param_sep_token) # Check if the new state is "text-gen" or "code-interpreter" or "parameter" if 
gen_state["func_name"] == "all": gen_state["stage"] = "text-gen" From d8fdff498045d452665b0e5226f5151a632fd91c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 16 Oct 2024 12:50:36 +0000 Subject: [PATCH 33/40] fix v2.5 --- .../prompt_template/llama3_prompt_template.py | 34 ++++++++++++------- tests/test_sgl_server.py | 9 ++++- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/functionary/prompt_template/llama3_prompt_template.py b/functionary/prompt_template/llama3_prompt_template.py index f7412ea..888f40a 100644 --- a/functionary/prompt_template/llama3_prompt_template.py +++ b/functionary/prompt_template/llama3_prompt_template.py @@ -193,18 +193,24 @@ def initialize_fsm_gen_state( else: stage = "pre-function" - return { + gen_state = { "stage": stage, "curr_tokens": curr_tokens, "curr_text": curr_text, "func_name": func_name, "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool - "gen_empty_text": True, # if first_time we return an empty delta with role=assistant - "first_time_func": True, + "first_chunk": True, + "first_function_chunk": True, "add_code_interpreter": add_code_interpreter, } + return ( + self._update_gen_state_for_fn_call(gen_state, func_name) + if func_name is not None + else gen_state + ) + def stream_delta_text( self, gen_state: Dict, @@ -228,16 +234,17 @@ def stream_delta_text( ) if gen_state["stage"] == "text-gen": - if gen_state["gen_empty_text"]: + if gen_state["first_chunk"]: responses.append( prompt_utils.get_text_delta_response("", True, finish_reason) ) - gen_state["gen_empty_text"] = False - responses.append( - prompt_utils.get_text_delta_response( - gen_state["curr_text"], False, finish_reason + gen_state["first_chunk"] = False + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_text_delta_response( + gen_state["curr_text"], False, finish_reason + ) ) - ) if delta_text != self.function_separator: responses.append( prompt_utils.get_text_delta_response( @@ -245,16 +252,17 @@ def stream_delta_text( ) ) elif gen_state["stage"] in ["parameter", "code-interpreter"]: - if gen_state["first_time_func"]: + if gen_state["first_function_chunk"]: responses.append( prompt_utils.get_function_delta_response( - gen_state, "", True, True, finish_reason + gen_state, "", True, gen_state["first_chunk"], finish_reason ) ) - gen_state["first_time_func"] = False + gen_state["first_chunk"] = False + gen_state["first_function_chunk"] = False responses.append( prompt_utils.get_function_delta_response( - gen_state, delta_text, True, False, finish_reason + gen_state, delta_text, False, False, finish_reason ) ) diff --git a/tests/test_sgl_server.py b/tests/test_sgl_server.py index 628e036..a63e011 100644 --- a/tests/test_sgl_server.py +++ b/tests/test_sgl_server.py @@ -335,7 +335,7 @@ def __init__(self, *args, **kwargs): super(TestSglServer, self).__init__(*args, **kwargs) self.served_models = [ # "meetkai/functionary-small-v2.4", - # "meetkai/functionary-small-v2.5", + "meetkai/functionary-small-v2.5", "meetkai/functionary-small-v3.1", "meetkai/functionary-small-v3.2", ] @@ -441,6 +441,13 @@ def test_sgl_server(self): self.client = OpenAI(base_url=f"{self.base_url}/v1", api_key="test") try: for test_case in self.request_handling_test_cases: + # v2.5 cannot generate this case + if ( + model == "meetkai/functionary-small-v2.5" + and test_case["test_aim"] + == 'Normal text generation (CoT) + code generation using "python" tool' + ): + continue pred = call_openai_api( test_case=test_case, client=self.client, From 
687ae300375be5fd1a0bd8b9f2ad38756392c935 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Mon, 21 Oct 2024 14:07:51 +0000 Subject: [PATCH 34/40] fixes to both sgl and vllm --- .../prompt_template/prompt_template_v2.py | 45 +- functionary/vllm_inference.py | 25 +- tests/test_server.py | 691 ++++++++++++++++++ tests/test_sgl_server.py | 477 ------------ 4 files changed, 732 insertions(+), 506 deletions(-) create mode 100644 tests/test_server.py delete mode 100644 tests/test_sgl_server.py diff --git a/functionary/prompt_template/prompt_template_v2.py b/functionary/prompt_template/prompt_template_v2.py index d48089f..711f9d6 100644 --- a/functionary/prompt_template/prompt_template_v2.py +++ b/functionary/prompt_template/prompt_template_v2.py @@ -236,20 +236,26 @@ def initialize_fsm_gen_state( add_all_recipient = True stage = "function" - return { + gen_state = { "stage": stage, "curr_tokens": curr_tokens, "curr_text": curr_text, "func_name": func_name, "func_index": -1, # index of the tool in tool_calls "call_id": None, # call_id of the current tool - "gen_empty_text": True, # if first_time we return an empty delta with role=assistant - "first_time_func": True, + "first_chunk": True, + "first_function_chunk": True, "prev_newline": False, "add_all_recipient": add_all_recipient, "add_code_interpreter": add_code_interpreter, } + return ( + self._update_gen_state_for_fn_call(gen_state, func_name) + if func_name is not None + else gen_state + ) + def stream_delta_text( self, gen_state: Dict, @@ -277,33 +283,50 @@ def stream_delta_text( gen_state["prev_newline"] = True elif gen_state["prev_newline"] and delta_text != self.from_token: responses.append( - prompt_utils.get_text_delta_response("\n", True, finish_reason) + prompt_utils.get_text_delta_response("\n", False, finish_reason) ) gen_state["prev_newline"] = False elif gen_state["prev_newline"] is False: - if gen_state["gen_empty_text"]: + if gen_state["first_chunk"]: responses.append( prompt_utils.get_text_delta_response("", True, finish_reason) ) - gen_state["gen_empty_text"] = False + gen_state["first_chunk"] = False + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_text_delta_response( + gen_state["curr_text"], False, finish_reason + ) + ) delta_text = delta_text.lstrip(" ") responses.append( prompt_utils.get_text_delta_response( delta_text, False, finish_reason ) ) - elif gen_state["stage"] == "parameter": - if gen_state["first_time_func"]: + elif gen_state["stage"] in ["parameter", "code-interpreter"]: + if gen_state["first_function_chunk"]: responses.append( prompt_utils.get_function_delta_response( - gen_state, "", True, True, finish_reason + gen_state, "", True, gen_state["first_chunk"], finish_reason ) ) - gen_state["first_time_func"] = False + gen_state["first_chunk"] = False + gen_state["first_function_chunk"] = False + if gen_state["curr_text"] != "": + responses.append( + prompt_utils.get_function_delta_response( + gen_state, + gen_state["curr_text"], + False, + False, + finish_reason, + ) + ) delta_text = delta_text.lstrip(" ") responses.append( prompt_utils.get_function_delta_response( - gen_state, delta_text, True, False, finish_reason + gen_state, delta_text, False, False, finish_reason ) ) diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 1f6a876..f0d9502 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -167,7 +167,7 @@ async def process_chat_completion( return error_check_ret model_name = request.model - request_id = 
f"cmpl-{random_uuid()}" + request_id = f"chatcmpl-{random_uuid()}" created_time = int(time.time()) # compute stop_token_ids @@ -291,22 +291,11 @@ async def completion_stream_generator( if response["finish_reason"] == "function_call": response["finish_reason"] = "tool_calls" - # Workaround Fixes - response["delta"]["role"] = "assistant" - if ( - "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - for tool_call in response["delta"]["tool_calls"]: - if tool_call.get("type") is None: - tool_call["type"] = "function" - chunk = StreamChoice(**response) result = ChatCompletionChunk( id=request_id, choices=[chunk], model=model_name ) - chunk_dic = result.dict(exclude_unset=True) + chunk_dic = result.model_dump() chunk_data = json.dumps(chunk_dic, ensure_ascii=False) yield f"data: {chunk_data}\n\n" # Break from for loop after the first tool_call is streamed if functions is provided @@ -360,11 +349,11 @@ async def completion_stream_generator( chat_mess["tool_calls"] = None # Postprocess finish reason - if "function_call" in chat_mess and chat_mess["function_call"]: - output.finish_reason = "function_call" - - if "tool_calls" in chat_mess and chat_mess["tool_calls"]: - output.finish_reason = "tool_calls" + if tool_func_choice is None or tool_func_choice in ["auto", "required"]: + if "function_call" in chat_mess and chat_mess["function_call"]: + output.finish_reason = "function_call" + if "tool_calls" in chat_mess and chat_mess["tool_calls"]: + output.finish_reason = "tool_calls" # Convert v1 from function_call to tool_calls if tools are provided instead of functions if ( diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..75b2526 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,691 @@ +import json +import subprocess +import time +import unittest +from typing import Dict, List, Literal, Optional + +import psutil +import requests +from openai import OpenAI +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) + +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 + + +def popen_launch_server( + backend: Literal["vllm", "sglang"], + model: str, + base_url: str, + timeout: float, + context_length: int, + grammar_sampling: bool, + env: Optional[dict] = None, + return_stdout_stderr: bool = False, +) -> subprocess.Popen: + """ + Launch a server process with specified backend and configuration. + + Args: + backend (Literal["vllm", "sglang"]): The backend to use for the server. + model (str): The model name to be used. + base_url (str): The base URL for the server. + timeout (float): Maximum time to wait for server launch. + context_length (int): The context length for the model. + grammar_sampling (bool): Whether to enable grammar sampling. + env (Optional[dict]): Environment variables for the subprocess. Defaults to None. + return_stdout_stderr (bool): Whether to capture and return stdout/stderr. Defaults to False. + + Returns: + subprocess.Popen: The launched server process. + + Raises: + TimeoutError: If the server fails to start within the specified timeout period. 
+ """ + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + f"server_{backend}.py", + "--model", + model, + "--host", + host, + "--port", + str(port), + ] + if backend == "vllm": + command += ["--max-model-len", str(context_length)] + else: + command += ["--context-length", str(context_length)] + if grammar_sampling: + command += ["--enable-grammar-sampling"] + + if return_stdout_stderr: + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + text=True, + ) + else: + process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + + start_time = time.time() + api_key = "test" + while time.time() - start_time < timeout: + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {api_key}", + } + response = requests.get(f"{base_url}/health", headers=headers) + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(5) + raise TimeoutError("Server failed to start within the timeout period.") + + +def kill_child_process(pid, including_parent=True, skip_pid=None) -> None: + """ + Kill the process and all its children processes. + + Args: + pid (int): The process ID of the parent process to kill. + including_parent (bool, optional): If True, kill the parent process as well. Defaults to True. + skip_pid (int, optional): Process ID to skip killing. Defaults to None. + + Returns: + None + + Raises: + psutil.NoSuchProcess: If the specified process does not exist. + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + children = parent.children(recursive=True) + for child in children: + if child.pid == skip_pid: + continue + try: + child.kill() + except psutil.NoSuchProcess: + pass + + if including_parent: + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + +def call_openai_api( + test_case: Dict, + client: OpenAI, + messages: List[Dict], + model: str, + default_tools: List, + python_tool: Dict, + default_functions: List, + stream: bool = False, +) -> ChatCompletion: + """ + Call the OpenAI API with the given parameters. + + Args: + test_case (Dict): A dictionary containing test case information. + client (OpenAI): An instance of the OpenAI client. + messages (List[Dict]): A list of message dictionaries to be sent to the API. + model (str): The name of the model to use for the API call. + default_tools (List): A list of default tools to be used in the API call. + python_tool (Dict): A dictionary representing the Python tool configuration. + default_functions (List): A list of default functions to be used in the API call. + stream (bool, optional): Whether to stream the response. Defaults to False. + + Returns: + The response from the OpenAI API. 
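+
+    Example (illustrative sketch; the test case, messages, and client settings mirror the fixtures defined in TestServer below):
+        client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="test")
+        response = call_openai_api(
+            test_case={"call_mode": "tools", "code_interpreter": False, "choice": "auto"},
+            client=client,
+            messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
+            model="meetkai/functionary-small-v3.1",
+            default_tools=default_tools,
+            python_tool=python_tool,
+            default_functions=default_functions,
+        )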
+ """ + if test_case["call_mode"] == "tools": + if test_case["code_interpreter"]: + if model.startswith("meetkai"): + tools = default_tools + [{"type": "code_interpreter"}] + else: + tools = default_tools + [python_tool] + else: + tools = default_tools + response = client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + tool_choice=test_case["choice"], + temperature=0.0, + stream=stream, + ) + else: + response = client.chat.completions.create( + model=model, + messages=messages, + functions=default_functions, + function_call=test_case["choice"], + temperature=0.0, + stream=stream, + ) + return response + + +def check_nonstreaming_response(pred: ChatCompletion, label: ChatCompletion) -> None: + """ + Check if the non-streaming response matches the expected label. + + This function compares various attributes of the predicted response (pred) with the + expected label response (label) to ensure they match. It checks the following: + - The ID prefix + - The object type + - The content of the message + - The tool calls (if any) + - The function call (if any) + - The finish reason + + Args: + pred (ChatCompletion): The predicted response from the API. + label (ChatCompletion): The expected (label) response to compare against. + + Raises: + AssertionError: If any of the checks fail, indicating a mismatch between + the predicted and expected responses. + """ + # Check if both label.id and pred.id start with the same prefix + assert pred.id.startswith(label.id[: label.id.index("-")]) + # Check if objects are equal + assert pred.object == label.object + pred_content = pred.choices[0].message.content + label_content = label.choices[0].message.content + pred_tool_calls = pred.choices[0].message.tool_calls + label_tool_calls = label.choices[0].message.tool_calls + pred_fn_call = pred.choices[0].message.function_call + label_fn_call = label.choices[0].message.function_call + # Check if content is equal + assert (pred_content is None) == (label_content is None) + # Check if tool_calls are equal + assert (pred_tool_calls is None) == (label_tool_calls is None) + if label_tool_calls is not None: + assert len(pred_tool_calls) == len(label_tool_calls) + for pred_tool_call, label_tool_call in zip(pred_tool_calls, label_tool_calls): + assert isinstance(pred_tool_call, ChatCompletionMessageToolCall) + assert pred_tool_call.id.startswith( + "call_" + ) and label_tool_call.id.startswith("call_") + assert pred_tool_call.type == label_tool_call.type + assert pred_tool_call.function.name == label_tool_call.function.name + assert pred_tool_call.function.arguments is not None + # Check if function_calls are equal + assert (pred_fn_call is None) == (label_fn_call is None) + if label_fn_call is not None: + assert isinstance(pred_fn_call, FunctionCall) + assert pred_fn_call.name == label_fn_call.name + assert pred_fn_call.arguments is not None + # Check finish_reason + assert pred.choices[0].finish_reason == label.choices[0].finish_reason + + +def check_streaming_response( + pred: List[ChatCompletionChunk], label: List[ChatCompletionChunk] +) -> None: + """ + Check the streaming response from the API against the expected label. + + This function compares a list of predicted ChatCompletionChunk objects + against a list of label ChatCompletionChunk objects. 
It verifies various + aspects of the streaming response, including: + - The consistency of chunk IDs + - The presence and correctness of content, function calls, and tool calls + - The proper structure and timing of chunks in the stream + + Args: + pred (List[ChatCompletionChunk]): The list of predicted chunks from the API. + label (List[ChatCompletionChunk]): The list of expected (label) chunks to compare against. + + Raises: + AssertionError: If any of the checks fail, indicating a mismatch between + the predicted and expected streaming responses. + """ + has_content, has_fn_call = False, False + num_tool_calls = 0 + if label[0].choices[0].delta.content and len(label[0].choices[0].delta.content) > 0: + has_content = True + for chunk in label: + if ( + chunk.choices[0].delta.tool_calls is not None + and len(chunk.choices[0].delta.tool_calls) > 0 + and chunk.choices[0].delta.tool_calls[0].function.name is not None + ): + num_tool_calls += 1 + if chunk.choices[0].delta.function_call is not None: + has_fn_call = True + tool_call_id = -1 + pred_has_content, pred_fn_call = False, False + for i, chunk in enumerate(pred): + # Check if both label.id and pred.id start with the same prefix + assert chunk.id.startswith(label[0].id[: label[0].id.index("-")]) + # Check if objects are equal + assert chunk.object == label[0].object + # Check if the assistant turn is in the first chunk only + if i == 0: + assert chunk.choices[0].delta.role == "assistant" + if ( + chunk.choices[0].delta.content + and len(chunk.choices[0].delta.content) > 0 + ): + pred_has_content = True + else: + assert chunk.choices[0].delta.role is None + # Check if the finish_reason is in the last chunk only + if i == len(pred) - 1: + assert chunk.choices[0].finish_reason is not None + else: + assert chunk.choices[0].finish_reason is None + # Check if only one of content, function_call or tool_calls is not None + non_none_fields = [ + chunk.choices[0].delta.content is not None, + chunk.choices[0].delta.function_call is not None, + chunk.choices[0].delta.tool_calls is not None, + ] + if i == len(pred) - 1: + assert sum(non_none_fields) == 0 + else: + assert sum(non_none_fields) == 1 + # Check tool_calls + if chunk.choices[0].delta.tool_calls is not None: + call_type = chunk.choices[0].delta.tool_calls[0].type + name = chunk.choices[0].delta.tool_calls[0].function.name + args = chunk.choices[0].delta.tool_calls[0].function.arguments + # Check name, arguments, call_type and index + assert args is not None + if len(args) == 0: + assert name is not None + assert call_type == "function" + tool_call_id += 1 + assert chunk.choices[0].delta.tool_calls[0].index == tool_call_id + else: + assert name is None + assert call_type is None + # Check function_call + if chunk.choices[0].delta.function_call is not None: + pred_fn_call = True + name = chunk.choices[0].delta.function_call.name + args = chunk.choices[0].delta.function_call.arguments + assert args is not None + if len(args) == 0: + assert name is not None + else: + assert name is None + # Check if pred has same tool_calls and function_call as label + assert pred_has_content == has_content + assert pred_fn_call == has_fn_call + assert num_tool_calls == tool_call_id + 1 + + +class TestServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.openai_model = "gpt-4o-mini-2024-07-18" + cls.base_url = "http://127.0.0.1:8000" + cls.default_functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": 
{ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + ] + cls.default_tools = [{"type": "function", "function": cls.default_functions[0]}] + cls.python_tool = { + "type": "function", + "function": { + "name": "python", + "description": "Generate Python code", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to execute", + } + }, + "required": ["code"], + }, + }, + } + cls.request_handling_test_cases = [ + { + "test_aim": 'Normal text gen with "auto"', + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Single tool_calls with "auto"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Single function_call with "auto"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "functions", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Parallel tool_calls with "auto"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Parallel function_calls with "auto"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "functions", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Normal text gen + tool_calls with "auto"', + "messages": { + "openai": [ + { + "role": "user", + "content": "What is the weather in Istanbul? Answer this question: 'How are you?', before checking the weather.", + } + ], + "meetkai/functionary-small-v2.4": [ + { + "role": "user", + "content": "Answer both these questions: 'How are you?' and 'What's the weather in Istanbul?'", + } + ], + "meetkai/functionary-small-v2.5": [ + { + "role": "user", + "content": "What is the weather in Istanbul? Answer this question: 'How are you?', before checking the weather.", + } + ], + "meetkai/functionary-small-v3.1": [ + { + "role": "user", + "content": "Answer both these questions: 'How are you?' and 'What's the weather in Istanbul?'", + } + ], + "meetkai/functionary-small-v3.2": [ + { + "role": "user", + "content": "What is the weather in Istanbul? 
Answer this question: 'How are you?', before checking the weather.", + } + ], + }, + "call_mode": "tools", + "code_interpreter": False, + "choice": "auto", + }, + { + "test_aim": 'Normal text gen with "none"', + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "none", + }, + { + "test_aim": "tool_calls with tool_choice", + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "tools", + "code_interpreter": False, + "choice": { + "type": "function", + "function": {"name": cls.default_functions[0]["name"]}, + }, + }, + { + "test_aim": "function_call with function_call", + "messages": [{"role": "user", "content": "How are you?"}], + "call_mode": "functions", + "code_interpreter": False, + "choice": {"name": cls.default_functions[0]["name"]}, + }, + { + "test_aim": 'parallel tool_calls with "required"', + "messages": [ + { + "role": "user", + "content": "What is the weather in Istanbul and Singapore respectively?", + } + ], + "call_mode": "tools", + "code_interpreter": False, + "choice": "required", + }, + { + "test_aim": 'code generation using "python" tool', + "messages": [ + { + "role": "user", + "content": "Use the Python tool to write a Python function that adds 2 integers.", + } + ], + "call_mode": "tools", + "code_interpreter": True, + "choice": "auto", + }, + { + "test_aim": 'Normal text generation (CoT) + code generation using "python" tool', + "messages": { + "openai": [ + { + "role": "user", + "content": "Write a Python function that adds 2 integers. Answer this question: 'How are you?', before using the python tool.", + } + ], + "meetkai/functionary-small-v2.4": [ + { + "role": "user", + "content": "Write a Python function that adds 2 integers. Answer this question: 'How are you?', before using the python tool.", + } + ], + "meetkai/functionary-small-v2.5": [ + { + "role": "user", + "content": "Answer both these questions: 'How are you?' and 'Write a Python function that adds 2 integers.'", + } + ], + "meetkai/functionary-small-v3.1": [ + { + "role": "user", + "content": "Write a Python function that adds 2 integers. Answer this question: 'How are you?', before using the python tool.", + } + ], + "meetkai/functionary-small-v3.2": [ + { + "role": "user", + "content": "Write a Python function that adds 2 integers. 
Answer this question: 'How are you?', before using the python tool.", + } + ], + }, + "call_mode": "tools", + "code_interpreter": True, + "choice": "auto", + }, + ] + + # Get the labels + cls.client = OpenAI() + for i, test_case in enumerate(cls.request_handling_test_cases): + if isinstance(test_case["messages"], dict): + messages = test_case["messages"]["openai"] + else: + messages = test_case["messages"] + response = call_openai_api( + test_case=test_case, + client=cls.client, + messages=messages, + model=cls.openai_model, + default_tools=cls.default_tools, + python_tool=cls.python_tool, + default_functions=cls.default_functions, + ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls and len(tool_calls) > 0: + for j in range(len(tool_calls)): + if tool_calls[j].function.name == "python": + response.choices[0].message.tool_calls[j].function.arguments = ( + json.loads(tool_calls[j].function.arguments)["code"] + ) + cls.request_handling_test_cases[i]["label"] = response + + response = call_openai_api( + test_case=test_case, + client=cls.client, + messages=messages, + model=cls.openai_model, + default_tools=cls.default_tools, + python_tool=cls.python_tool, + default_functions=cls.default_functions, + stream=True, + ) + chunks = [chunk for chunk in response] + cls.request_handling_test_cases[i]["stream_label"] = chunks + + # Point the client towards the local server before running the tests + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="test") + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_child_process(cls.process.pid) + + def __init__(self, *args, **kwargs): + super(TestServer, self).__init__(*args, **kwargs) + self.served_models = [ + "meetkai/functionary-small-v2.4", + "meetkai/functionary-small-v2.5", + "meetkai/functionary-small-v3.1", + "meetkai/functionary-small-v3.2", + ] + + def _evaluate_test_cases(self, model: str) -> None: + """ + Evaluate test cases for a given model. + + This method runs through all the test cases in self.request_handling_test_cases, + making API calls to both streaming and non-streaming endpoints. It compares + the responses against pre-computed labels to ensure correctness. + + Args: + model (str): The name of the model to evaluate. + + Raises: + AssertionError: If any test case fails, with a message indicating which + test case failed. + + Note: + This method will kill the server process after evaluation, regardless + of whether the tests pass or fail. 
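+
+        Example (illustrative; mirrors how test_sgl_server below launches a server and then runs the evaluation):
+            self.process = popen_launch_server(
+                backend="sglang",
+                model="meetkai/functionary-small-v3.1",
+                base_url=self.base_url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                context_length=4096,
+                grammar_sampling=False,
+            )
+            self._evaluate_test_cases("meetkai/functionary-small-v3.1")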
+ """ + try: + for test_case in self.request_handling_test_cases: + if isinstance(test_case["messages"], dict): + messages = test_case["messages"][model] + else: + messages = test_case["messages"] + # Check non-streaming + pred = call_openai_api( + test_case=test_case, + client=self.client, + messages=messages, + model=model, + default_tools=self.default_tools, + python_tool=self.python_tool, + default_functions=self.default_functions, + ) + label = test_case["label"] + check_nonstreaming_response(pred, label) + # Check streaming + pred = call_openai_api( + test_case=test_case, + client=self.client, + messages=messages, + model=model, + default_tools=self.default_tools, + python_tool=self.python_tool, + default_functions=self.default_functions, + stream=True, + ) + pred = [chunk for chunk in pred] + label = test_case["stream_label"] + check_streaming_response(pred, label) + except AssertionError: + print(f"test case {test_case['test_aim']} failed") + raise + finally: + if self.process: + kill_child_process(self.process.pid) + + def test_vllm_server(self): + for model in self.served_models: + for grammar_sample in [False, True]: + self.process = popen_launch_server( + backend="vllm", + model=model, + base_url=self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + context_length=4096, + grammar_sampling=grammar_sample, + ) + self._evaluate_test_cases(model) + + def test_sgl_server(self): + for model in self.served_models: + self.process = popen_launch_server( + backend="sglang", + model=model, + base_url=self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + context_length=4096, + grammar_sampling=False, + ) + self._evaluate_test_cases(model) diff --git a/tests/test_sgl_server.py b/tests/test_sgl_server.py deleted file mode 100644 index a63e011..0000000 --- a/tests/test_sgl_server.py +++ /dev/null @@ -1,477 +0,0 @@ -import json -import subprocess -import time -import unittest -from typing import Dict, List, Optional - -import psutil -import requests -from openai import OpenAI -from openai.types.chat.chat_completion_message import FunctionCall -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, -) -from rich import print - -DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 - - -def popen_launch_sgl_server( - model: str, - base_url: str, - timeout: float, - context_length: int, - grammar_sampling: bool, - env: Optional[dict] = None, - return_stdout_stderr: bool = False, -): - _, host, port = base_url.split(":") - host = host[2:] - - command = [ - "python3", - "server_sglang.py", - "--model", - model, - "--host", - host, - "--port", - str(port), - "--context-length", - str(context_length), - ] - if grammar_sampling: - command += ["--enable-grammar-sampling"] - - if return_stdout_stderr: - process = subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - text=True, - ) - else: - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) - - start_time = time.time() - api_key = "test" - while time.time() - start_time < timeout: - try: - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {api_key}", - } - response = requests.get(f"{base_url}/health", headers=headers) - if response.status_code == 200: - return process - except requests.RequestException: - pass - time.sleep(5) - raise TimeoutError("Server failed to start within the timeout period.") - - -def kill_child_process(pid, including_parent=True, skip_pid=None): - """Kill the process and all its 
children process.""" - try: - parent = psutil.Process(pid) - except psutil.NoSuchProcess: - return - - children = parent.children(recursive=True) - for child in children: - if child.pid == skip_pid: - continue - try: - child.kill() - except psutil.NoSuchProcess: - pass - - if including_parent: - try: - parent.kill() - except psutil.NoSuchProcess: - pass - - -def call_openai_api( - test_case: Dict, - client: OpenAI, - model: str, - default_tools: List, - python_tool: Dict, - default_functions: List, - stream: bool = False, -): - if test_case["call_mode"] == "tools": - if test_case["code_interpreter"]: - if model.startswith("meetkai"): - tools = default_tools + [{"type": "code_interpreter"}] - else: - tools = default_tools + [python_tool] - else: - tools = default_tools - response = client.chat.completions.create( - model=model, - messages=test_case["messages"], - tools=tools, - tool_choice=test_case["choice"], - temperature=0.0, - stream=stream, - ) - else: - response = client.chat.completions.create( - model=model, - messages=test_case["messages"], - functions=default_functions, - function_call=test_case["choice"], - temperature=0.0, - stream=stream, - ) - return response - - -class TestSglServer(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.openai_model = "gpt-4o-mini-2024-07-18" - cls.base_url = "http://127.0.0.1:8000" - cls.default_functions = [ - { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - } - }, - "required": ["location"], - }, - }, - ] - cls.default_tools = [{"type": "function", "function": cls.default_functions[0]}] - cls.python_tool = { - "type": "function", - "function": { - "name": "python", - "description": "Generate Python code", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The Python code to execute", - } - }, - "required": ["code"], - }, - }, - } - cls.request_handling_test_cases = [ - { - "test_aim": 'Normal text gen with "auto"', - "messages": [{"role": "user", "content": "How are you?"}], - "call_mode": "tools", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Single tool_calls with "auto"', - "messages": [ - {"role": "user", "content": "What is the weather in Istanbul?"} - ], - "call_mode": "tools", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Single function_call with "auto"', - "messages": [ - {"role": "user", "content": "What is the weather in Istanbul?"} - ], - "call_mode": "functions", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Parallel tool_calls with "auto"', - "messages": [ - { - "role": "user", - "content": "What is the weather in Istanbul and Singapore respectively?", - } - ], - "call_mode": "tools", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Parallel function_calls with "auto"', - "messages": [ - { - "role": "user", - "content": "What is the weather in Istanbul and Singapore respectively?", - } - ], - "call_mode": "functions", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Normal text gen + tool_calls with "auto"', - "messages": [ - { - "role": "user", - "content": "What is the weather in Istanbul? 
Answer this question: 'How are you?', before checking the weather.", - } - ], - "call_mode": "tools", - "code_interpreter": False, - "choice": "auto", - }, - { - "test_aim": 'Normal text gen with "none"', - "messages": [ - {"role": "user", "content": "What is the weather in Istanbul?"} - ], - "call_mode": "tools", - "code_interpreter": False, - "choice": "none", - }, - { - "test_aim": "tool_calls with tool_choice", - "messages": [{"role": "user", "content": "How are you?"}], - "call_mode": "tools", - "code_interpreter": False, - "choice": { - "type": "function", - "function": {"name": cls.default_functions[0]["name"]}, - }, - }, - { - "test_aim": "function_call with function_call", - "messages": [{"role": "user", "content": "How are you?"}], - "call_mode": "functions", - "code_interpreter": False, - "choice": {"name": cls.default_functions[0]["name"]}, - }, - { - "test_aim": 'parallel tool_calls with "required"', - "messages": [ - { - "role": "user", - "content": "What is the weather in Istanbul and Singapore respectively?", - } - ], - "call_mode": "tools", - "code_interpreter": False, - "choice": "required", - }, - { - "test_aim": 'code generation using "python" tool', - "messages": [ - { - "role": "user", - "content": "Use the Python tool to write a Python function that adds 2 integers.", - } - ], - "call_mode": "tools", - "code_interpreter": True, - "choice": "auto", - }, - { - "test_aim": 'Normal text generation (CoT) + code generation using "python" tool', - "messages": [ - { - "role": "user", - "content": "Write a Python function that adds 2 integers. Answer this question: 'How are you?', before using the python tool.", - } - ], - "call_mode": "tools", - "code_interpreter": True, - "choice": "auto", - }, - ] - cls.client = OpenAI() - for i, test_case in enumerate(cls.request_handling_test_cases): - response = call_openai_api( - test_case=test_case, - client=cls.client, - model=cls.openai_model, - default_tools=cls.default_tools, - python_tool=cls.python_tool, - default_functions=cls.default_functions, - ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls and len(tool_calls) > 0: - for j in range(len(tool_calls)): - if tool_calls[j].function.name == "python": - response.choices[0].message.tool_calls[j].function.arguments = ( - json.loads(tool_calls[j].function.arguments)["code"] - ) - cls.request_handling_test_cases[i]["label"] = response - - response = call_openai_api( - test_case=test_case, - client=cls.client, - model=cls.openai_model, - default_tools=cls.default_tools, - python_tool=cls.python_tool, - default_functions=cls.default_functions, - stream=True, - ) - chunks = [chunk for chunk in response] - cls.request_handling_test_cases[i]["stream_label"] = chunks - - @classmethod - def tearDownClass(cls): - if hasattr(cls, "process") and cls.process: - kill_child_process(cls.process.pid) - - def __init__(self, *args, **kwargs): - super(TestSglServer, self).__init__(*args, **kwargs) - self.served_models = [ - # "meetkai/functionary-small-v2.4", - "meetkai/functionary-small-v2.5", - "meetkai/functionary-small-v3.1", - "meetkai/functionary-small-v3.2", - ] - - def _check_nonstreaming_response(self, pred, label): - # Check if both label.id and pred.id start with the same prefix - assert pred.id.startswith(label.id[: label.id.index("-")]) - # Check if objects are equal - assert pred.object == label.object - pred_content = pred.choices[0].message.content - label_content = label.choices[0].message.content - pred_tool_calls = pred.choices[0].message.tool_calls - 
label_tool_calls = label.choices[0].message.tool_calls - pred_fn_call = pred.choices[0].message.function_call - label_fn_call = label.choices[0].message.function_call - # Check if content is equal - assert (pred_content is None) == (label_content is None) - # Check if tool_calls are equal - assert (pred_tool_calls is None) == (label_tool_calls is None) - if label_tool_calls is not None: - assert len(pred_tool_calls) == len(label_tool_calls) - for pred_tool_call, label_tool_call in zip( - pred_tool_calls, label_tool_calls - ): - assert isinstance(pred_tool_call, ChatCompletionMessageToolCall) - assert pred_tool_call.id.startswith( - "call_" - ) and label_tool_call.id.startswith("call_") - assert pred_tool_call.type == label_tool_call.type - assert pred_tool_call.function.name == label_tool_call.function.name - assert pred_tool_call.function.arguments is not None - # Check if function_calls are equal - assert (pred_fn_call is None) == (label_fn_call is None) - if label_fn_call is not None: - assert isinstance(pred_fn_call, FunctionCall) - assert pred_fn_call.name == label_fn_call.name - assert pred_fn_call.arguments is not None - # Check finish_reason - assert pred.choices[0].finish_reason == label.choices[0].finish_reason - - def _check_streaming_response(self, pred, label): - tool_call_id = -1 - for i, chunk in enumerate(pred): - # Check if both label.id and pred.id start with the same prefix - assert chunk.id.startswith(label[0].id[: label[0].id.index("-")]) - # Check if objects are equal - assert chunk.object == label[0].object - # Check if the assistant turn is in the first chunk only - if i == 0: - assert chunk.choices[0].delta.role == "assistant" - else: - assert chunk.choices[0].delta.role is None - # Check if the finish_reason is in the last chunk only - if i == len(pred) - 1: - assert chunk.choices[0].finish_reason is not None - else: - assert chunk.choices[0].finish_reason is None - # Check if only one of content, function_call or tool_calls is not None - non_none_fields = [ - chunk.choices[0].delta.content is not None, - chunk.choices[0].delta.function_call is not None, - chunk.choices[0].delta.tool_calls is not None, - ] - if i == len(pred) - 1: - assert sum(non_none_fields) == 0 - else: - assert sum(non_none_fields) == 1 - # Check tool_calls - if chunk.choices[0].delta.tool_calls is not None: - call_type = chunk.choices[0].delta.tool_calls[0].type - name = chunk.choices[0].delta.tool_calls[0].function.name - args = chunk.choices[0].delta.tool_calls[0].function.arguments - # Check name, arguments, call_type and index - assert args is not None - if len(args) == 0: - assert name is not None - assert call_type == "function" - tool_call_id += 1 - assert chunk.choices[0].delta.tool_calls[0].index == tool_call_id - else: - assert name is None - assert call_type is None - # Check function_call - if chunk.choices[0].delta.function_call is not None: - name = chunk.choices[0].delta.function_call.name - args = chunk.choices[0].delta.function_call.arguments - assert args is not None - if len(args) == 0: - assert name is not None - else: - assert name is None - - def test_sgl_server(self): - for model in self.served_models: - self.process = popen_launch_sgl_server( - model=model, - base_url=self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - context_length=4096, - grammar_sampling=False, - ) - - self.client = OpenAI(base_url=f"{self.base_url}/v1", api_key="test") - try: - for test_case in self.request_handling_test_cases: - # v2.5 cannot generate this case - if ( - model == 
"meetkai/functionary-small-v2.5" - and test_case["test_aim"] - == 'Normal text generation (CoT) + code generation using "python" tool' - ): - continue - pred = call_openai_api( - test_case=test_case, - client=self.client, - model=model, - default_tools=self.default_tools, - python_tool=self.python_tool, - default_functions=self.default_functions, - ) - label = test_case["label"] - self._check_nonstreaming_response(pred, label) - pred = call_openai_api( - test_case=test_case, - client=self.client, - model=model, - default_tools=self.default_tools, - python_tool=self.python_tool, - default_functions=self.default_functions, - stream=True, - ) - pred = [chunk for chunk in pred] - label = test_case["stream_label"] - self._check_streaming_response(pred, label) - except AssertionError: - raise - finally: - if self.process: - kill_child_process(self.process.pid) From 2c987f74e7d54b5b0e4b441cacb11b75aca8efd9 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Mon, 21 Oct 2024 15:22:38 +0000 Subject: [PATCH 35/40] fixes --- .github/workflows/python-package.yml | 3 +- README.md | 53 +++++++++++--------- functionary/inference_utils.py | 74 ++++++++++++++++++++++++++++ functionary/sglang_inference.py | 33 +++++-------- functionary/vllm_inference.py | 68 ++----------------------- server_sglang.py | 70 +++++++++++++++----------- 6 files changed, 164 insertions(+), 137 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 1c9ffa4..711aacf 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -40,4 +40,5 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - pytest + pytest tests --ignore=tests/test_server.py +# Ignore test_server.py for now as it requires a GPU runner diff --git a/README.md b/README.md index 24687eb..85cd4e9 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,9 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m Changelog: (click to expand) - + [2024-08-11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html) + + [2024/10/21] New server powered by [SGLang](https://github.com/sgl-project/sglang)! 
+ + [2024/08/21] We release [meetkai/functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) and [meetkai/functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2) + + [2024/08/11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html) + [2024/08/08] We release 128k-context length 70B-model: [meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) that are based on [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) + [2024/08/07] We release 2 128k-context length models that are based on [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct): + [meetkai/functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1): **using Meta's original prompt template** as described in: [User-defined Custom tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1#user-defined-custom-tool-calling) @@ -29,44 +31,52 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m -### Setup +## Getting Started -To install the required dependencies, run: +Functionary can be deployed using either our [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) or [SGLang](https://sglang.readthedocs.io/en/latest/install.html) servers. Choose either one depending on your preferences. +### Installation + +**vLLM** ```shell pip install -r requirements.txt ``` +**SGLang** +```shell +pip install -r requirements_sgl.txt +``` -Now you can start a blazing fast [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) server. -[requirements](https://docs.vllm.ai/en/latest/getting_started/installation.html#requirements) +### Running the server -**Small Model:** +#### Small Model + +**vLLM** +```shell +python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --max-model-len 8192 +``` +**SGLang** ```shell -python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --max-model-len 8192 +python3 server_sglang.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --context-length 8192 ``` -**Medium Model:** +#### Medium Model -Our medium models require: 4xA6000 or 2xA100 80GB to run, need to use: `tensor-parallel-size` +Our medium models require: 4xA6000 or 2xA100 80GB to run, need to use: `tensor-parallel-size` or `tp` (SGLang) +**vLLM** ```shell # vllm requires to run this first: https://github.com/vllm-project/vllm/issues/6152 export VLLM_WORKER_MULTIPROC_METHOD=spawn -python server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 +python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --max-model-len 8192 --tensor-parallel-size 2 ``` - -
- SGLang - +**SGLang** ```shell -python server_sglang.py --model-path meetkai/functionary-medium-v3.2 --port 8000 --host 0.0.0.0 --tp 8 +python server_sglang.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2 ``` -
- -**Grammar Sampling** +### Grammar Sampling (Only in vLLM) We also offer our own function-calling grammar sampling feature which constrains the LLM's generation to always follow the prompt template, and ensures 100% accuracy for function name. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that the parameters follow the schema of the tool called. To enable grammar sampling, run the vLLM server with the command-line argument --enable-grammar-sampling: @@ -74,12 +84,10 @@ We also offer our own function-calling grammar sampling feature which constrains python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling ``` -Note: -- Grammar Sampling support is applicable only for the V2 and V3.0 models. There is no such support for V1 and V3.1 models. -- Our vLLM server supports the `tool_choice="required"` feature in OpenAI Chat Completion API exclusively **only when grammar sampling is enabled**. +**Note:** Grammar Sampling support is applicable only for the V2, V3.0, V3.2 models. There is no such support for V1 and V3.1 models. -**Text-Generation-Inference** +### Text-Generation-Inference (TGI) We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started: @@ -208,6 +216,7 @@ print(response.text) ## Models Available | Model | Description | VRAM FP16 | |:-------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------|:------| +| [functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2) | 128k context, code interpreter, using **our own prompt template** | 160GB | | [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | 128k context, code interpreter, using **our own prompt template** | 24GB | | [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 160GB | | [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 24GB | diff --git a/functionary/inference_utils.py b/functionary/inference_utils.py index c7db22f..72c2fd6 100644 --- a/functionary/inference_utils.py +++ b/functionary/inference_utils.py @@ -1,9 +1,22 @@ +from http import HTTPStatus +from typing import Optional + import torch +from fastapi.responses import JSONResponse +from pydantic import BaseModel from transformers import StoppingCriteria, StoppingCriteriaList from functionary.prompt_template.prompt_utils import enforce_tool_choice +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + class StopWordsCriteria(StoppingCriteria): def __init__(self, stops=[]): StoppingCriteria.__init__(self) @@ -35,3 +48,64 @@ def analyze_tools_and_tool_choice(request): tool_func_choice = "none" return tools_or_functions, tool_func_choice + + +def 
create_error_response( + status_code: HTTPStatus, message: str, param: Optional[str] +) -> JSONResponse: + return JSONResponse( + ErrorResponse( + message=message, + type="invalid_request_error", + param=param, + code=status_code.value, + ).dict(), + status_code=status_code.value, + ) + + +async def check_all_errors(request, served_model) -> Optional[JSONResponse]: + if request.model not in served_model: + return create_error_response( + status_code=HTTPStatus.NOT_FOUND, + message=f"The model `{request.model}` does not exist.", + param=None, + ) + if request.tools and request.functions: + return create_error_response( + status_code=HTTPStatus.BAD_REQUEST, + message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.", + param=None, + ) + if isinstance(request.function_call, str) and request.function_call not in [ + "none", + "auto", + ]: + return create_error_response( + status_code=HTTPStatus.BAD_REQUEST, + message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.", + param="function_call", + ) + if isinstance(request.tool_choice, str) and request.tool_choice not in [ + "none", + "auto", + "required", + ]: + return create_error_response( + status_code=HTTPStatus.BAD_REQUEST, + message=f"Invalid value: '{request.tool_choice}'. Supported values are: 'none', 'auto', and 'required'.", + param="tool_choice", + ) + if request.functions is None and request.function_call is not None: + return create_error_response( + status_code=HTTPStatus.BAD_REQUEST, + message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.", + param="function_call", + ) + if request.tools is None and request.tool_choice is not None: + return create_error_response( + status_code=HTTPStatus.BAD_REQUEST, + message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.", + param="tool_choice", + ) + return diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index a3cd03d..0c65789 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -38,7 +38,11 @@ from transformers import AutoTokenizer from functionary.inference_stream import generate_openai_format_from_stream_async -from functionary.inference_utils import analyze_tools_and_tool_choice +from functionary.inference_utils import ( + analyze_tools_and_tool_choice, + check_all_errors, + create_error_response, +) from functionary.openai_types import ( ChatCompletionChunk, ChatCompletionRequest, @@ -79,25 +83,6 @@ class ChatCompletionParams: grammar_sampling: bool -def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -): - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - return JSONResponse(content=error.model_dump(), status_code=error.code) - - -def create_streaming_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -) -> str: - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - json_str = json.dumps({"error": error.model_dump()}) - return json_str - - def convert_tool_calls_to_function_call( functions: Optional[List[Function]], chat_message: Dict ) -> Dict: @@ -507,7 +492,7 @@ async def v1_chat_generate_completion( params.adapted_request, params.raw_request ).__anext__() except ValueError as e: - return None, create_error_response(str(e)) 
+ return None, create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return ret["text"], None @@ -581,6 +566,7 @@ async def v1_chat_completions( tokenizer_manager: Optional[TokenizerManager], srt_backend: Optional[Runtime], raw_request: Request, + served_model: List[str], ): """ Handle chat completions for v1 of the API. @@ -615,6 +601,11 @@ async def v1_chat_completions( prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer) tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request) + # Check for errors + error_check_ret = await check_all_errors(request, served_model) + if error_check_ret is not None: + return error_check_ret + # Generate the adapted request adapted_request, request = v1_chat_generate_request( request, tokenizer, tools_or_functions, tool_func_choice, return_text=False diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index f0d9502..2dff613 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -5,14 +5,17 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.entrypoints.openai.protocol import ErrorResponse from vllm.inputs import TokensPrompt from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from functionary.inference_stream import generate_openai_format_from_stream_async -from functionary.inference_utils import analyze_tools_and_tool_choice +from functionary.inference_utils import ( + analyze_tools_and_tool_choice, + check_all_errors, + create_error_response, +) from functionary.openai_types import ( ChatCompletionChunk, ChatCompletionRequest, @@ -33,67 +36,6 @@ ) -def create_error_response( - status_code: HTTPStatus, message: str, param: Optional[str] -) -> JSONResponse: - return JSONResponse( - ErrorResponse( - message=message, - type="invalid_request_error", - param=param, - code=status_code.value, - ).dict(), - status_code=status_code.value, - ) - - -async def check_all_errors(request, served_model) -> Optional[JSONResponse]: - if request.model not in served_model: - return create_error_response( - status_code=HTTPStatus.NOT_FOUND, - message=f"The model `{request.model}` does not exist.", - param=None, - ) - if request.tools and request.functions: - return create_error_response( - status_code=HTTPStatus.BAD_REQUEST, - message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.", - param=None, - ) - if isinstance(request.function_call, str) and request.function_call not in [ - "none", - "auto", - ]: - return create_error_response( - status_code=HTTPStatus.BAD_REQUEST, - message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.", - param="function_call", - ) - if isinstance(request.tool_choice, str) and request.tool_choice not in [ - "none", - "auto", - "required", - ]: - return create_error_response( - status_code=HTTPStatus.BAD_REQUEST, - message=f"Invalid value: '{request.tool_choice}'. 
Supported values are: 'none', 'auto', and 'required'.", - param="tool_choice", - ) - if request.functions is None and request.function_call is not None: - return create_error_response( - status_code=HTTPStatus.BAD_REQUEST, - message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.", - param="function_call", - ) - if request.tools is None and request.tool_choice is not None: - return create_error_response( - status_code=HTTPStatus.BAD_REQUEST, - message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.", - param="tool_choice", - ) - return - - async def check_length(request, input_ids, model_config): if hasattr(model_config.hf_config, "max_sequence_length"): context_len = model_config.hf_config.max_sequence_length diff --git a/server_sglang.py b/server_sglang.py index 29bac1e..66dfc67 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -68,6 +68,7 @@ app = FastAPI() tokenizer_manager = None +served_model = [] app.add_middleware( CORSMiddleware, @@ -163,18 +164,20 @@ async def stream_results(): async def openai_v1_chat_completions(raw_request: Request): global tokenizer_manager, backend - if not args.grammar_sampling: - backend = None - return await v1_chat_completions(tokenizer_manager, backend, raw_request) + # if not args.grammar_sampling: + # backend = None + return await v1_chat_completions(tokenizer_manager, None, raw_request, served_model) @app.get("/v1/models") def available_models(): """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + if isinstance(served_model, list): + for model in served_model: + model_cards.append(ModelCard(id=model, root=model)) + else: + model_cards.append(ModelCard(id=served_model, root=served_model)) return ModelList(data=model_cards) @@ -357,31 +360,38 @@ def __init__( default=None, help="enable detailed request input/output logging by providing logfile", ) - parser.add_argument( - "--enable-grammar-sampling", - dest="grammar_sampling", - action="store_true", - default=False, - help="enable grammar sampling for function names", - ) + # parser.add_argument( + # "--enable-grammar-sampling", + # dest="grammar_sampling", + # action="store_true", + # default=False, + # help="enable grammar sampling for function names", + # ) ServerArgs.add_cli_args(parser) args = parser.parse_args() + + served_model = [args.model_path] + if args.served_model_name is not None: + served_model += args.served_model_name + server_args = ServerArgs.from_cli_args(args) - if args.grammar_sampling: - backend = FunctionaryRuntime(**vars(server_args)) - sgl.set_default_backend( - sgl.RuntimeEndpoint( - f"http://{backend.server_args.host}:{backend.server_args.port}" - ) - ) - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - else: - launch_server(server_args) + launch_server(server_args) + + # if args.grammar_sampling: + # backend = FunctionaryRuntime(**vars(server_args)) + # sgl.set_default_backend( + # sgl.RuntimeEndpoint( + # f"http://{backend.server_args.host}:{backend.server_args.port}" + # ) + # ) + # uvicorn.run( + # app, + # host=server_args.host, + # port=server_args.port, + # log_level=server_args.log_level_http or server_args.log_level, + # timeout_keep_alive=5, + # 
loop="uvloop", + # ) + # else: + # launch_server(server_args) From 54fa17657548d2a125656718791a8c796869918d Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Wed, 23 Oct 2024 05:04:19 +0000 Subject: [PATCH 36/40] upgrade sglang to v0.3.4.post1 --- requirements_sgl.txt | 2 +- server_sglang.py | 87 ++++++++++++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 32 deletions(-) diff --git a/requirements_sgl.txt b/requirements_sgl.txt index 6bffbd5..8ed060b 100644 --- a/requirements_sgl.txt +++ b/requirements_sgl.txt @@ -1,4 +1,4 @@ jsonref~=1.1.0 -sglang[all]==0.3.3 +sglang[all]==0.3.4.post1 --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ flashinfer==0.1.6 diff --git a/server_sglang.py b/server_sglang.py index 66dfc67..8ed3e75 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -22,14 +22,12 @@ import logging import multiprocessing as mp import os -import socket -import sys import threading import time from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import AsyncIterator, Dict, List, Optional, Union -import requests +import orjson # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -39,8 +37,12 @@ import uvloop from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from uvicorn.config import LOGGING_CONFIG from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.scheduler import run_scheduler_process @@ -132,14 +134,18 @@ async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: - async def stream_results(): + async def stream_results() -> AsyncIterator[bytes]: try: async for out in tokenizer_manager.generate_request(obj, request): - yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" except ValueError as e: out = {"error": {"message": str(e)}} - yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + yield b"data: [DONE]\n\n" return StreamingResponse( stream_results(), @@ -151,7 +157,7 @@ async def stream_results(): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return JSONResponse( + return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST ) @@ -169,7 +175,7 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, None, raw_request, served_model) -@app.get("/v1/models") +@app.get("/v1/models", response_class=ORJSONResponse) def available_models(): """Show available models.""" model_cards = [] @@ -202,30 +208,40 @@ def launch_engine(server_args: ServerArgs): server_args.model_path, server_args.tokenizer_path ) - # Launch tensor parallel scheduler processes - scheduler_procs = [] - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - 
tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + scheduler_procs = [] + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + + if server_args.node_rank >= 1: + # For other nodes, they do not need to run tokenizer or detokenizer, + # so they can just wait here. + while True: + pass + else: + # Launch the data parallel controller reader, writer = mp.Pipe(duplex=False) - gpu_id = tp_rank % tp_size_per_node + scheduler_pipe_readers = [reader] proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, writer), + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), ) proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. - while True: - pass # Launch detokenizer process detoken_proc = mp.Process( @@ -286,6 +302,14 @@ def launch_server( try: # Listen for HTTP requests + LOGGING_CONFIG["formatters"]["default"][ + "fmt" + ] = "[%(asctime)s] %(levelprefix)s %(message)s" + LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + LOGGING_CONFIG["formatters"]["access"][ + "fmt" + ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' + LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" uvicorn.run( app, host=server_args.host, @@ -327,6 +351,7 @@ def __init__( self.url = self.server_args.url() self.generate_url = self.url + "/generate" + # NOTE: We store pid instead of proc to fix some issues during __delete__ self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) From 8a87a6270b50e642d5def90bb622649601f74016 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Fri, 25 Oct 2024 11:17:04 +0000 Subject: [PATCH 37/40] change to pyproject.toml --- .github/workflows/python-package.yml | 2 +- README.md | 4 +-- pyproject.toml | 42 ++++++++++++++++++++++++++++ requirements.txt | 15 ---------- requirements_sgl.txt | 4 --- 5 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 pyproject.toml delete mode 100644 requirements.txt delete mode 100644 requirements_sgl.txt diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 711aacf..125d568 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -29,7 +29,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + if [ -f pyproject.toml ]; then pip install -e .[vllm]; fi - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/README.md b/README.md index 85cd4e9..83278d4 100644 --- a/README.md +++ b/README.md @@ -39,11 +39,11 @@ Functionary can be deployed using either our 
[vLLM](https://vllm.readthedocs.io/ **vLLM** ```shell -pip install -r requirements.txt +pip install -e .[vllm] ``` **SGLang** ```shell -pip install -r requirements_sgl.txt +pip install -e .[sglang] --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ ``` ### Running the server diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2d03dda --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "functionary" +version = "0.0.1" +description = "Chat language model that can use tools and interpret the results" +requires-python = ">=3.9" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["functionary"] + +[project.optional-dependencies] +vllm = [ + "transformers==4.43.3", + "accelerate~=0.21.0", + "sentencepiece~=0.1.99", + "fastapi~=0.111.0", + "uvicorn~=0.23.1", + "pydantic~=2.6.0", + "scipy~=1.11.1", + "jsonref~=1.1.0", + "requests~=2.31.0", + "PyYAML~=6.0.1", + "protobuf==3.20.0", + "tokenizers==0.19.1", + "vllm==0.5.4; sys_platform != 'darwin'", + "json_source_map==1.0.5", + "jinja2==3.1.4", +] +sglang = [ + "jsonref~=1.1.0", + "python-multipart==0.0.12", + "orjson==3.10.10", + "sglang[all]==0.3.4.post1", + "flashinfer==0.1.6", +] + +[project.urls] +homepage = "https://github.com/meetkai/functionary" +bugtracker = "https://github.com/meetkai/functionary/issues" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 4164362..0000000 --- a/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -transformers==4.43.3 -accelerate~=0.21.0 -sentencepiece~=0.1.99 -fastapi~=0.111.0 -uvicorn~=0.23.1 -pydantic~=2.6.0 -scipy~=1.11.1 -jsonref~=1.1.0 -requests~=2.31.0 -PyYAML~=6.0.1 -protobuf==3.20.0 -tokenizers==0.19.1 -vllm==0.5.4; sys_platform != "darwin" -json_source_map==1.0.5 -jinja2==3.1.4 diff --git a/requirements_sgl.txt b/requirements_sgl.txt deleted file mode 100644 index 8ed060b..0000000 --- a/requirements_sgl.txt +++ /dev/null @@ -1,4 +0,0 @@ -jsonref~=1.1.0 -sglang[all]==0.3.4.post1 ---find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ -flashinfer==0.1.6 From 0944d7c8a77273c8cb2b66fd6d38d3d267678a1c Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 5 Nov 2024 11:10:24 +0000 Subject: [PATCH 38/40] update pyproject.toml --- pyproject.toml | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2d03dda..4f5e561 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,21 +13,10 @@ packages = ["functionary"] [project.optional-dependencies] vllm = [ - "transformers==4.43.3", - "accelerate~=0.21.0", - "sentencepiece~=0.1.99", - "fastapi~=0.111.0", - "uvicorn~=0.23.1", - "pydantic~=2.6.0", - "scipy~=1.11.1", + "vllm==0.6.3.post1; sys_platform != 'darwin'", "jsonref~=1.1.0", - "requests~=2.31.0", - "PyYAML~=6.0.1", - "protobuf==3.20.0", - "tokenizers==0.19.1", - "vllm==0.5.4; sys_platform != 'darwin'", "json_source_map==1.0.5", - "jinja2==3.1.4", + "PyYAML~=6.0.1", ] sglang = [ "jsonref~=1.1.0", From 0c25249c52b1566d550618aa4828354907e7bf8e Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Tue, 5 Nov 2024 13:59:00 +0000 Subject: [PATCH 39/40] refactor tool_call to function_call; vllm health endpoint --- functionary/inference_utils.py | 35 ++++++++++++++++++++++++++++- functionary/sglang_inference.py | 39 ++++++--------------------------- functionary/vllm_inference.py | 34 +++++++++------------------- server_vllm.py | 7 ++++-- 4 files changed, 56 
insertions(+), 59 deletions(-) diff --git a/functionary/inference_utils.py b/functionary/inference_utils.py index 8855223..89b5c35 100644 --- a/functionary/inference_utils.py +++ b/functionary/inference_utils.py @@ -1,6 +1,6 @@ from copy import deepcopy from http import HTTPStatus -from typing import Optional +from typing import Dict, List, Optional import jsonref import torch @@ -8,6 +8,7 @@ from pydantic import BaseModel from transformers import StoppingCriteria, StoppingCriteriaList +from functionary.openai_types import Function from functionary.prompt_template.prompt_utils import enforce_tool_choice @@ -128,3 +129,35 @@ def resolve_json_refs(tools_or_functions): ) return tools + + +def convert_tool_calls_to_function_call( + functions: Optional[List[Function]], chat_message: Dict +) -> Dict: + if "delta" not in chat_message: # Non-streaming + if ( + functions + and len(functions) > 0 + and "tool_calls" in chat_message + and chat_message["tool_calls"] is not None + and len(chat_message["tool_calls"]) > 0 + ): + chat_message["function_call"] = { + "name": chat_message["tool_calls"][0]["function"]["name"], + "arguments": chat_message["tool_calls"][0]["function"]["arguments"], + } + chat_message["tool_calls"] = None + else: # Streaming + if ( + functions + and len(functions) > 0 + and "tool_calls" in chat_message["delta"] + and chat_message["delta"]["tool_calls"] + and len(chat_message["delta"]["tool_calls"]) > 0 + ): + chat_message["delta"]["function_call"] = chat_message["delta"][ + "tool_calls" + ][0]["function"] + chat_message["delta"]["tool_calls"] = None + + return chat_message diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 0c65789..52fce26 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -41,6 +41,7 @@ from functionary.inference_utils import ( analyze_tools_and_tool_choice, check_all_errors, + convert_tool_calls_to_function_call, create_error_response, ) from functionary.openai_types import ( @@ -83,25 +84,6 @@ class ChatCompletionParams: grammar_sampling: bool -def convert_tool_calls_to_function_call( - functions: Optional[List[Function]], chat_message: Dict -) -> Dict: - if ( - functions - and len(functions) > 0 - and "tool_calls" in chat_message - and chat_message["tool_calls"] is not None - and len(chat_message["tool_calls"]) > 0 - ): - chat_message["function_call"] = { - "name": chat_message["tool_calls"][0]["function"]["name"], - "arguments": chat_message["tool_calls"][0]["function"]["arguments"], - } - chat_message["tool_calls"] = None - - return chat_message - - def v1_chat_generate_request( request: ChatCompletionRequest, tokenizer: AutoTokenizer, @@ -382,19 +364,12 @@ async def completion_stream_generator(params: ChatCompletionParams): params.tools_or_functions, ): # Convert tool_calls to function_call if request.functions is provided - if ( - params.request.functions - and len(params.request.functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] - response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ - "function" - ] - response["delta"]["tool_calls"] = None + response = convert_tool_calls_to_function_call( + functions=params.request.functions, chat_message=response + ) + if response["delta"]["function_call"]: + tool_name = 
response["delta"]["function_call"]["name"] + tool_args = response["delta"]["function_call"]["arguments"] if tool_name and len(tool_name) > 0 and tool_args == "": tool_call_count += 1 diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 9989620..e5e1725 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -14,6 +14,7 @@ from functionary.inference_utils import ( analyze_tools_and_tool_choice, check_all_errors, + convert_tool_calls_to_function_call, create_error_response, ) from functionary.openai_types import ( @@ -193,19 +194,12 @@ async def completion_stream_generator( ): # Convert tool_calls to function_call if request.functions is provided - if ( - functions - and len(functions) > 0 - and "tool_calls" in response["delta"] - and response["delta"]["tool_calls"] - and len(response["delta"]["tool_calls"]) > 0 - ): - tool_name = response["delta"]["tool_calls"][0]["function"]["name"] - tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"] - response["delta"]["function_call"] = response["delta"]["tool_calls"][0][ - "function" - ] - response["delta"]["tool_calls"] = None + response = convert_tool_calls_to_function_call( + functions=request.functions, chat_message=response + ) + if response["delta"]["function_call"]: + tool_name = response["delta"]["function_call"]["name"] + tool_args = response["delta"]["function_call"]["arguments"] if tool_name and len(tool_name) > 0 and tool_args == "": tool_call_count += 1 # Return finish_reason after the first tool_call is streamed if functions is provided @@ -277,17 +271,9 @@ async def completion_stream_generator( ) # parse_generated_content(text_response) # Convert tool_calls to function_call if request.functions is provided - if ( - request.functions - and "tool_calls" in chat_mess - and chat_mess["tool_calls"] is not None - and len(chat_mess["tool_calls"]) > 0 - ): - chat_mess["function_call"] = { - "name": chat_mess["tool_calls"][0]["function"]["name"], - "arguments": chat_mess["tool_calls"][0]["function"]["arguments"], - } - chat_mess["tool_calls"] = None + chat_mess = convert_tool_calls_to_function_call( + functions=request.functions, chat_message=chat_mess + ) # Postprocess finish reason if tool_func_choice is None or tool_func_choice in ["auto", "required"]: diff --git a/server_vllm.py b/server_vllm.py index 53f394d..41e390e 100644 --- a/server_vllm.py +++ b/server_vllm.py @@ -27,8 +27,9 @@ import vllm.entrypoints.openai.api_server as vllm_api_server from fastapi import Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import Response from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.api_server import health, mount_metrics +from vllm.entrypoints.openai.api_server import mount_metrics from vllm.entrypoints.openai.protocol import ModelCard, ModelList, ModelPermission from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer @@ -51,7 +52,9 @@ @app.get("/health") async def _health(): """Health check.""" - return await health() + # vLLM's OpenAI server's health check is too heavy and also requires + # creating engine_client here, so we just return 200 here. 
+ return Response(status_code=200) @app.get("/v1/models") From 035cc72c81235b876c3f2b7e1c70e06feccba8b5 Mon Sep 17 00:00:00 2001 From: jeffreymeetkai <104876655+jeffreymeetkai@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:37:20 +0800 Subject: [PATCH 40/40] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 83278d4..9af5182 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 - ``` **SGLang** ```shell -python3 server_sglang.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --context-length 8192 +python3 server_sglang.py --model-path "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --context-length 8192 ``` #### Medium Model @@ -72,7 +72,7 @@ python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 - ``` **SGLang** ```shell -python server_sglang.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2 +python server_sglang.py --model-path "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2 ```
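
For reference, below is a minimal sketch (not part of the patches above) of calling the OpenAI-compatible `/v1/chat/completions` route that these patches exercise, once one of the servers from the README commands is running. It assumes the small v3.2 model served on `localhost:8000`, reuses the `get_current_weather` schema from the test suite earlier in this series, and passes a placeholder `api_key`, mirroring `api_key="test"` in the tests.

```python
# Minimal sketch: query a locally running Functionary server (vLLM or SGLang)
# through its OpenAI-compatible /v1 route. Assumes it was launched as in the
# README, e.g.:
#   python3 server_sglang.py --model-path "meetkai/functionary-small-v3.2" \
#       --host 0.0.0.0 --port 8000 --context-length 8192
from openai import OpenAI

# Placeholder key, mirroring api_key="test" used by the test suite.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="test")

# Tool schema reused from the get_current_weather definition in the tests.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="meetkai/functionary-small-v3.2",
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=tools,
    tool_choice="auto",
    temperature=0.0,
)
print(response.choices[0].message.tool_calls)
```

Passing `stream=True` instead returns a sequence of chunks shaped like the `stream_label` chunks collected in the tests, with the tool name arriving in the first delta of each call and the arguments streamed in the following deltas.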