Improve benchmark scripts (#1672)
merrymercy authored Oct 15, 2024
1 parent 4a292f6 commit 175afed
Showing 2 changed files with 121 additions and 18 deletions.
33 changes: 22 additions & 11 deletions python/sglang/bench_server_latency.py
@@ -6,6 +6,8 @@
 Usage:
 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -15,7 +17,7 @@
 import multiprocessing
 import os
 import time
-from typing import Tuple
+from typing import Optional, Tuple

 import numpy as np
 import requests
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-    proc, base_url = launch_server_process(server_args)
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-    run_one_case(
-        base_url,
-        batch_size=16,
-        input_len=1024,
-        output_len=16,
-        run_name="",
-        result_filename="",
-    )
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-        kill_child_process(proc.pid)
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
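Net effect for bench_server_latency.py: with --base-url set, run_benchmark skips launch_server_process and proc stays None, so the finally block no longer tries to kill a child process; --skip-warmup bypasses the fixed batch_size=16, input_len=1024 warmup case. A usage sketch combining the new flags, assuming a server is already listening on localhost:30000 (batch size and lengths are illustrative):

python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --skip-warmup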
106 changes: 99 additions & 7 deletions python/sglang/bench_serving.py
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion APIs send a final usage
+                            # summary chunk without a token, so check that a
+                            # token was actually generated.
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
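The new async_request_sglang_generate targets the server's native /generate route, streams chunks, and reads each chunk's "text" field. A minimal non-streaming sketch of the same payload shape, assuming a server at http://localhost:30000 (prompt and sampling values are illustrative):

import requests

# Mirror the payload built in async_request_sglang_generate above,
# but with streaming disabled so the response is a single JSON object.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "ignore_eos": True,
        },
        "stream": False,
    },
)
resp.raise_for_status()
print(resp.json()["text"])  # the generated completion text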
@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang": async_request_openai_completions,
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset
@@ -784,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend == "trt":
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
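With this routing, --backend sglang (or sglang-native) now benchmarks the native /generate endpoint, while --backend sglang-oai keeps the previous OpenAI-compatible /v1/completions path; vllm and lmdeploy are unchanged. Invocation sketches against a local server, assuming the script's existing --backend and --num-prompts flags (request count is illustrative):

python3 -m sglang.bench_serving --backend sglang --num-prompts 1000
python3 -m sglang.bench_serving --backend sglang-oai --num-prompts 1000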
