Commit 88814d4

[WIP] merge router into weight update
2 parents 8c8710e + 84a1698

33 files changed: +2529 -703 lines

.github/workflows/release-pypi-router.yml (+5)

@@ -3,6 +3,11 @@
 name: Release SGLang Router to PyPI
 
 on:
+  push:
+    branches:
+      - main
+    paths:
+      - rust/pyproject.toml
   workflow_dispatch:
 
 jobs:

benchmark/multi_turn_chat/long_prompt_multi_turn.py (+42 -11)

@@ -1,21 +1,24 @@
 import itertools
 import json
+import os
 import random
 import string
 import threading
 import time
 from argparse import ArgumentParser
+from pathlib import Path
+from typing import Union
+
+from tqdm import tqdm
 
 import sglang as sgl
-from sglang.srt.hf_transformers_utils import get_tokenize
+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 from sglang.utils import dump_state_text
 
-random.seed(42)
-
 
 def gen_prompt(tokenizer, token_num):
     all_available_tokens = list(tokenizer.get_vocab().values())
@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
     return ret
 
 
+def get_cache_path(args):
+    # Create cache directory under ~/.cache/sglang
+    cache_dir = Path.home() / ".cache" / "sglang"
+
+    # Create a unique cache filename based on the arguments that affect generation
+    cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
+    return cache_dir / cache_key
+
+
 def gen_arguments(args, tokenizer):
-    multi_qas = [
-        {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
-        for _ in range(args.num_qa)
-    ]
-    for i in range(args.num_qa):
+    cache_path = get_cache_path(args)
+
+    # Try to load from cache first
+    if cache_path.exists():
+        print(f"Loading cached arguments from {cache_path}")
+        with open(cache_path, "r") as f:
+            return json.load(f)
+
+    print("Generating new arguments...")
+    # First progress bar for system prompts
+    multi_qas = []
+    for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
+        multi_qas.append(
+            {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
+        )
+
+    # Nested progress bars for QA pairs
+    for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
         qas = multi_qas[i]["qas"]
         for j in range(args.turns):
             qas.append(
@@ -38,14 +63,21 @@ def gen_arguments(args, tokenizer):
                     "new_tokens": args.len_a,
                 }
             )
+
+    # Save to cache
+    cache_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(cache_path, "w") as f:
+        json.dump(multi_qas, f)
+    print(f"Cached arguments saved to {cache_path}")
+
     return multi_qas
 
 
 @sgl.function
 def multi_turns(s, system_prompt, qas):
     s += system_prompt
 
-    for qa in qas:
+    for i, qa in enumerate(qas):
         s += qa["prompt"]
         s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
 
@@ -62,7 +94,7 @@ def main(args):
         multi_qas,
         temperature=0,
         backend=backend,
-        num_threads=args.parallel,
+        num_threads="auto",
         progress_bar=True,
     )
     latency = time.time() - tic
@@ -75,7 +107,6 @@
         value = {
            "task": "multi_turn_system_prompt_chat",
            "backend": args.backend,
-            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,

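The gen_arguments() change above is a straightforward disk-cache pattern: derive a cache path from the arguments that affect generation, return the cached JSON when it exists, otherwise generate the data and write it back. A minimal standalone sketch of the same pattern (cached_json and build_fn are illustrative names, not part of the commit):

import json
from pathlib import Path

def cached_json(cache_key, build_fn):
    # Cache files live under ~/.cache/sglang, as in the benchmark script.
    cache_path = Path.home() / ".cache" / "sglang" / cache_key
    if cache_path.exists():
        print(f"Loading cached arguments from {cache_path}")
        with open(cache_path, "r") as f:
            return json.load(f)
    data = build_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_path, "w") as f:
        json.dump(data, f)
    print(f"Cached arguments saved to {cache_path}")
    return data

# Example: cached_json("qa_demo.json", lambda: [{"system_prompt": "...", "qas": []}])
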
python/pyproject.toml (+5)

@@ -31,6 +31,9 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
 # xpu is not enabled in public vllm and torch whl,
 # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
 srt_xpu = ["sglang[runtime_common]"]
+# For Intel Gaudi (device: hpu) follow the installation guide
+# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
+srt_hpu = ["sglang[runtime_common]"]
 
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
@@ -46,9 +49,11 @@ test = [
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]
 dev_hip = ["sglang[all_hip]", "sglang[test]"]
 dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
+dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
 
 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"

python/sglang/bench_one_batch.py (+1 -4)

@@ -278,10 +278,7 @@ def correctness_test(
 
 
 def synchronize(device):
-    if device == "cuda":
-        torch.cuda.synchronize()
-    elif device == "xpu":
-        torch.xpu.synchronize()
+    torch.get_device_module(device).synchronize()
 
 
 def latency_test_run_once(

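The bench_one_batch.py change replaces the per-backend if/elif chain with torch.get_device_module, which maps a device string to its backend module. A small sketch of how the generalized helper reads in isolation, assuming a PyTorch version that exposes torch.get_device_module:

import torch

def synchronize(device: str) -> None:
    # torch.get_device_module("cuda") returns torch.cuda, "xpu" returns torch.xpu,
    # and so on, so one call covers every accelerator backend.
    torch.get_device_module(device).synchronize()

if torch.cuda.is_available():
    synchronize("cuda")  # equivalent to torch.cuda.synchronize()
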
python/sglang/bench_serving.py (+39 -3)

@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
     total_input_tokens = 0
     total_output_tokens = 0
 
-    for group_idx in range(num_groups):
+    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
         system_prompt = system_prompts[group_idx]
-        for prompt_idx in range(prompts_per_group):
+        for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
             question = questions[group_idx * prompts_per_group + prompt_idx]
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))
@@ -859,6 +859,7 @@
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
+    max_concurrency: Optional[int],
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
     profile: bool,
@@ -868,6 +869,15 @@
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
+    # From https://github.com/vllm-project/vllm/pull/9390
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
@@ -913,7 +923,7 @@
         )
         tasks.append(
             asyncio.create_task(
-                request_func(request_func_input=request_func_input, pbar=pbar)
+                limited_request_func(request_func_input=request_func_input, pbar=pbar)
             )
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -940,6 +950,12 @@
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Max request concurrency:",
+            max_concurrency if max_concurrency else "not set",
+        )
+    )
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
@@ -1003,6 +1019,7 @@
         "backend": args.backend,
         "dataset_name": args.dataset_name,
         "request_rate": request_rate,
+        "max_concurrency": max_concurrency,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "total_output_tokens_retokenized": metrics.total_output_retokenized,
@@ -1090,6 +1107,10 @@ def run_benchmark(args_: argparse.Namespace):
     global args
     args = args_
 
+    # Set default value for max_concurrency if not present
+    if not hasattr(args, "max_concurrency"):
+        args.max_concurrency = None
+
     # Set global environments
     set_ulimit()
     random.seed(args.seed)
@@ -1201,6 +1222,7 @@ def run_benchmark(args_: argparse.Namespace):
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=args.request_rate,
+                max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
                 profile=args.profile,
@@ -1220,6 +1242,7 @@ def run_benchmark(args_: argparse.Namespace):
                     tokenizer=tokenizer,
                     input_requests=input_requests,
                     request_rate=rate,
+                    max_concurrency=args.max_concurrency,
                     disable_tqdm=args.disable_tqdm,
                     extra_request_body=extra_request_body,
                     profile=args.profile,
@@ -1319,6 +1342,19 @@ def set_ulimit(target_soft_limit=65535):
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=None,
+        help="Maximum number of concurrent requests. This can be used "
+        "to help simulate an environment where a higher level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument will control how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate, "
+        "if the server is not processing requests fast enough to keep up.",
+    )
     parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",

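The new --max-concurrency flag caps in-flight requests with an asyncio.Semaphore, while --request-rate still controls when requests are launched; once the cap is reached, the effective request rate falls below the configured one. A self-contained sketch of the same gating pattern, with fake_request standing in for the benchmark's real request function:

import asyncio
from typing import List, Optional

async def fake_request(i: int) -> int:
    # Stand-in for the benchmark's request function (it would issue an HTTP call).
    await asyncio.sleep(0.1)
    return i

async def run(num_requests: int, max_concurrency: Optional[int]) -> List[int]:
    # Only gate requests when a concurrency limit is set, mirroring benchmark().
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i: int) -> int:
        if semaphore is None:
            return await fake_request(i)
        async with semaphore:
            return await fake_request(i)

    return await asyncio.gather(*(limited(i) for i in range(num_requests)))

# e.g. asyncio.run(run(100, max_concurrency=8)) keeps at most 8 requests in flight.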