Commit
Merge pull request #277 from MeetKai/sglang
Implement SGLang Server
jeffreymeetkai authored Nov 6, 2024
2 parents 0231886 + 035cc72 commit d47e4a6
Showing 17 changed files with 2,168 additions and 179 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python-package.yml
@@ -29,7 +29,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f pyproject.toml ]; then pip install -e .[vllm]; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -40,4 +40,5 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
pytest
pytest tests --ignore=tests/test_server.py
# Ignore test_server.py for now as it requires a GPU runner
50 changes: 34 additions & 16 deletions README.md
@@ -15,7 +15,9 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

<summary>Changelog: (click to expand)</summary>

+ [2024-08-11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/10/21] New server powered by [SGLang](https://github.com/sgl-project/sglang)!
+ [2024/08/21] We release [meetkai/functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) and [meetkai/functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2)
+ [2024/08/11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/08/08] We release the 128k-context-length 70B model [meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1), which is based on [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+ [2024/08/07] We release 2 128k-context length models that are based on [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct):
+ [meetkai/functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1): **using Meta's original prompt template** as described in: [User-defined Custom tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1#user-defined-custom-tool-calling)
@@ -29,48 +31,63 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

</details>

### Setup
## Getting Started

To install the required dependencies, run:
Functionary can be deployed using either our [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) or [SGLang](https://sglang.readthedocs.io/en/latest/install.html) servers. Choose either one depending on your preferences.

### Installation

**vLLM**
```shell
pip install -e .[vllm]
```
**SGLang**
```shell
pip install -r requirements.txt
pip install -e .[sglang] --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/
```

Now you can start a blazing fast [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) server.
[requirements](https://docs.vllm.ai/en/latest/getting_started/installation.html#requirements)
### Running the server

#### Small Model

**Small Model:**
**vLLM**
```shell
python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --max-model-len 8192
python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --max-model-len 8192
```
**SGLang**
```shell
python3 server_sglang.py --model-path "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --context-length 8192
```

**Medium Model:**
#### Medium Model

Our medium models require: 4xA6000 or 2xA100 80GB to run, need to use: `tensor-parallel-size`
Our medium models require 4xA6000 or 2xA100 80GB GPUs to run. Use `tensor-parallel-size` (vLLM) or `tp` (SGLang) to split the model across GPUs.

**vLLM**
```shell
# vLLM requires this to be run first: https://github.com/vllm-project/vllm/issues/6152
export VLLM_WORKER_MULTIPROC_METHOD=spawn

python server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2
python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --max-model-len 8192 --tensor-parallel-size 2
```
**SGLang**
```shell
python server_sglang.py --model-path "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2
```
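
Once either server is up, it exposes an OpenAI-compatible chat completions API on the configured host and port. Below is a minimal sketch of a tool-calling request using the official `openai` Python client; the base URL, the placeholder API key, and the `get_current_weather` tool are illustrative assumptions, not part of this commit.

```python
# Minimal sketch: calling a running Functionary server (vLLM or SGLang)
# through its OpenAI-compatible API. Assumes the small-model server above
# is listening on localhost:8000; the tool definition is for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")  # any non-empty key

response = client.chat.completions.create(
    model="meetkai/functionary-small-v3.2",
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```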


**Grammar Sampling**
### Grammar Sampling (Only in vLLM)

We also offer our own function-calling grammar sampling feature, which constrains the LLM's generation to always follow the prompt template and guarantees 100% accuracy for the function name. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that they follow the schema of the called tool. To enable grammar sampling, run the vLLM server with the command-line argument <code>--enable-grammar-sampling</code>:

```shell
python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling
```

Note:
- Grammar Sampling support is applicable only for the V2 and V3.0 models. There is no such support for V1 and V3.1 models.
- Our vLLM server supports the `tool_choice="required"` feature in OpenAI Chat Completion API exclusively **only when grammar sampling is enabled**.
**Note:** Grammar sampling is supported only for the V2, V3.0, and V3.2 models; there is no such support for the V1 and V3.1 models.


**Text-Generation-Inference**
### Text-Generation-Inference (TGI)

We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started:

@@ -199,6 +216,7 @@ print(response.text)
## Models Available
| Model | Description | VRAM FP16 |
|:-------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------|:------|
| [functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2) | 128k context, code interpreter, using **our own prompt template** | 160GB |
| [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 160GB |
| [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 24GB |
106 changes: 106 additions & 0 deletions functionary/inference_utils.py
@@ -1,12 +1,25 @@
from copy import deepcopy
from http import HTTPStatus
from typing import Dict, List, Optional

import jsonref
import torch
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList

from functionary.openai_types import Function
from functionary.prompt_template.prompt_utils import enforce_tool_choice


class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int


class StopWordsCriteria(StoppingCriteria):
def __init__(self, stops=[]):
StoppingCriteria.__init__(self)
@@ -40,6 +53,67 @@ def analyze_tools_and_tool_choice(request):
return tools_or_functions, tool_func_choice


def create_error_response(
status_code: HTTPStatus, message: str, param: Optional[str]
) -> JSONResponse:
return JSONResponse(
ErrorResponse(
message=message,
type="invalid_request_error",
param=param,
code=status_code.value,
).dict(),
status_code=status_code.value,
)


async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
if request.model not in served_model:
return create_error_response(
status_code=HTTPStatus.NOT_FOUND,
message=f"The model `{request.model}` does not exist.",
param=None,
)
if request.tools and request.functions:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.",
param=None,
)
if isinstance(request.function_call, str) and request.function_call not in [
"none",
"auto",
]:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.",
param="function_call",
)
if isinstance(request.tool_choice, str) and request.tool_choice not in [
"none",
"auto",
"required",
]:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value: '{request.tool_choice}'. Supported values are: 'none', 'auto', and 'required'.",
param="tool_choice",
)
if request.functions is None and request.function_call is not None:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.",
param="function_call",
)
if request.tools is None and request.tool_choice is not None:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.",
param="tool_choice",
)
return


def resolve_json_refs(tools_or_functions):
tools = deepcopy(tools_or_functions)
if tools:
@@ -55,3 +129,35 @@ def resolve_json_refs(tools_or_functions):
)

return tools


def convert_tool_calls_to_function_call(
functions: Optional[List[Function]], chat_message: Dict
) -> Dict:
if "delta" not in chat_message: # Non-streaming
if (
functions
and len(functions) > 0
and "tool_calls" in chat_message
and chat_message["tool_calls"] is not None
and len(chat_message["tool_calls"]) > 0
):
chat_message["function_call"] = {
"name": chat_message["tool_calls"][0]["function"]["name"],
"arguments": chat_message["tool_calls"][0]["function"]["arguments"],
}
chat_message["tool_calls"] = None
else: # Streaming
if (
functions
and len(functions) > 0
and "tool_calls" in chat_message["delta"]
and chat_message["delta"]["tool_calls"]
and len(chat_message["delta"]["tool_calls"]) > 0
):
chat_message["delta"]["function_call"] = chat_message["delta"][
"tool_calls"
][0]["function"]
chat_message["delta"]["tool_calls"] = None

return chat_message
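
A quick, illustrative sketch of how the new `convert_tool_calls_to_function_call` helper behaves in the non-streaming case; the sample message dict and the assumption that `Function` can be constructed from just a name are for illustration only.

```python
# Illustrative sketch (not part of the commit): when a caller used the
# deprecated `functions` field, the first tool call is copied into
# `function_call` and `tool_calls` is cleared.
from functionary.inference_utils import convert_tool_calls_to_function_call
from functionary.openai_types import Function

functions = [Function(name="get_current_weather")]  # assumes only `name` is required
chat_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_current_weather", "arguments": '{"location": "Istanbul"}'},
        }
    ],
}

converted = convert_tool_calls_to_function_call(functions, chat_message)
assert converted["function_call"]["name"] == "get_current_weather"
assert converted["tool_calls"] is None
```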
11 changes: 11 additions & 0 deletions functionary/openai_types.py
@@ -130,11 +130,22 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None

    # logprobs and top_logprobs are currently disabled
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None

# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False

    # Extra parameters for the SRT (SGLang Runtime) backend only; they will be ignored by OpenAI models.
regex: Optional[str] = None
min_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)

# @validator("tool_choice", always=True)
# def validate_tool_choice(cls, value, values):
# if value is None:
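
The request fields added above (`regex`, `min_tokens`, `repetition_penalty`, `stop_token_ids`) are honored only by the SGLang (SRT) backend. Here is a hedged sketch of passing them from the OpenAI Python client via `extra_body`, assuming the SGLang server from the README is running locally:

```python
# Sketch: forwarding SGLang-only sampling parameters through the
# OpenAI-compatible API. `extra_body` merges these keys into the request
# JSON, where they map onto the new ChatCompletionRequest fields; other
# backends are expected to ignore them.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")
response = client.chat.completions.create(
    model="meetkai/functionary-small-v3.2",
    messages=[{"role": "user", "content": "Reply with a single word."}],
    extra_body={
        "min_tokens": 1,             # force at least one generated token
        "repetition_penalty": 1.05,  # mild penalty on repeated tokens
    },
)
print(response.choices[0].message.content)
```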
2 changes: 1 addition & 1 deletion functionary/prompt_template/base_template.py
@@ -245,7 +245,7 @@ def _update_gen_state_for_fn_call(self, gen_state: Dict, func_name: str):
gen_state["func_name"] = func_name
gen_state["func_index"] += 1
gen_state["call_id"] = prompt_utils.get_random_tool_call_id()
gen_state["first_time_func"] = True
gen_state["first_function_chunk"] = True

return gen_state

41 changes: 24 additions & 17 deletions functionary/prompt_template/llama31_prompt_template.py
@@ -148,8 +148,8 @@ def initialize_fsm_gen_state(
"func_name": func_name,
"func_index": -1, # index of the tool in tool_calls
"call_id": None, # call_id of the current tool
"gen_empty_text": True, # if first_time we return an empty delta with role=assistant
"first_time_func": True,
"first_chunk": True,
"first_function_chunk": True,
"text_to_func_buffer": [],
"clear_buffer": False,
"add_code_interpreter": add_code_interpreter,
@@ -182,14 +182,14 @@ def stream_delta_text(
)

if gen_state["stage"] == "text-gen":
if gen_state["gen_empty_text"]:
if gen_state["first_chunk"]:
responses.append(
prompt_utils.get_text_delta_response("", True, finish_reason)
)
gen_state["gen_empty_text"] = False
gen_state["first_chunk"] = False
responses.append(
prompt_utils.get_text_delta_response(
gen_state["curr_text"], True, finish_reason
gen_state["curr_text"], False, finish_reason
)
)
text_in_buffer = "".join(gen_state["text_to_func_buffer"] + [delta_text])
@@ -201,32 +201,38 @@
delta_text_to_stream = gen_state["text_to_func_buffer"][0]
responses.append(
prompt_utils.get_text_delta_response(
delta_text_to_stream, True, finish_reason
delta_text_to_stream, False, finish_reason
)
)
gen_state["text_to_func_buffer"] = gen_state["text_to_func_buffer"][
1:
]
responses.append(
prompt_utils.get_text_delta_response(
delta_text, True, finish_reason
delta_text, False, finish_reason
)
)
else:
gen_state["text_to_func_buffer"].append(delta_text)
elif gen_state["stage"] == "parameter":
if gen_state["first_time_func"]:
gen_state["first_time_func"] = False
if gen_state["first_function_chunk"]:
responses.append(
prompt_utils.get_function_delta_response(
gen_state, "", True, False, finish_reason
gen_state, "", True, gen_state["first_chunk"], finish_reason
)
)
responses.append(
prompt_utils.get_function_delta_response(
gen_state, gen_state["curr_text"], False, False, finish_reason
gen_state["first_chunk"] = False
gen_state["first_function_chunk"] = False
if gen_state["curr_text"] != "":
responses.append(
prompt_utils.get_function_delta_response(
gen_state,
gen_state["curr_text"],
False,
False,
finish_reason,
)
)
)

if "</" in delta_text:
delta_args = delta_text.removesuffix("</")
Expand All @@ -251,11 +257,12 @@ def stream_delta_text(
)
)
elif gen_state["stage"] == "code-interpreter":
if gen_state["first_time_func"]:
gen_state["first_time_func"] = False
if gen_state["first_function_chunk"]:
first_function_response = prompt_utils.get_function_delta_response(
gen_state, "", True, False, finish_reason
gen_state, "", True, gen_state["first_chunk"], finish_reason
)
gen_state["first_chunk"] = False
gen_state["first_function_chunk"] = False
responses.append(first_function_response)
responses.append(
prompt_utils.get_function_delta_response(
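
For context on the renamed flags above (`gen_empty_text`/`first_time_func` becoming `first_chunk`/`first_function_chunk`): the new scheme appears to let the very first delta of a response carry the assistant role exactly once, whether the model starts with text or goes straight into a tool call. A simplified, illustrative sketch of that gating pattern (not the actual template code):

```python
# Simplified illustration of the first-chunk gating; the real logic lives in
# llama31_prompt_template.stream_delta_text and differs in detail.
def emit_deltas(gen_state, delta_text, in_function):
    deltas = []
    if in_function:
        if gen_state["first_function_chunk"]:
            # Open the tool call; only the first delta of the whole response
            # also carries role="assistant".
            deltas.append({"open_tool_call": True, "with_role": gen_state["first_chunk"]})
            gen_state["first_chunk"] = False
            gen_state["first_function_chunk"] = False
        if delta_text:
            deltas.append({"arguments": delta_text})
    else:
        if gen_state["first_chunk"]:
            # Empty content delta that introduces the assistant role.
            deltas.append({"content": "", "with_role": True})
            gen_state["first_chunk"] = False
        deltas.append({"content": delta_text, "with_role": False})
    return deltas
```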