[router] cache-aware load-balancing router v1 (#2114)
ByronHsu authored Nov 23, 2024
1 parent ad47749 commit cbedd1d
Showing 17 changed files with 1,959 additions and 598 deletions.
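None of the routing logic itself appears in the files rendered below (they cover benchmarks, docs, and the Rust crate's dependencies), so as orientation, here is a minimal sketch of what cache-aware load balancing generally means: send each request to the worker whose KV/radix cache already holds the longest prefix of the prompt, and fall back to least-loaded routing when no worker has a useful match. Everything in this sketch (the names, the string-based cache stand-in, the threshold) is a hypothetical illustration, not the API introduced by this commit.

```python
# Hypothetical sketch of cache-aware routing; not the router's real code.
def pick_worker(prompt, cached_prefixes, loads, min_match=32):
    """cached_prefixes: worker_url -> a recently served prompt (a toy stand-in
    for a per-worker prefix tree); loads: worker_url -> in-flight requests."""

    def shared_prefix_len(a, b):
        # Length of the common leading substring of a and b.
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    # Prefer the worker holding the longest cached prefix of this prompt ...
    best_url, best_len = None, 0
    for url, cached in cached_prefixes.items():
        m = shared_prefix_len(prompt, cached)
        if m > best_len:
            best_url, best_len = url, m

    if best_len >= min_match:
        return best_url  # reuse that worker's warm KV cache
    # ... otherwise balance load: pick the worker with the fewest in-flight requests.
    return min(loads, key=loads.get)
```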
53 changes: 42 additions & 11 deletions benchmark/multi_turn_chat/long_prompt_multi_turn.py
@@ -1,21 +1,24 @@
 import itertools
+import json
 import os
 import random
 import string
 import threading
 import time
 from argparse import ArgumentParser
+from pathlib import Path
 from typing import Union

+from tqdm import tqdm

 import sglang as sgl
-from sglang.srt.hf_transformers_utils import get_tokenize
+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 from sglang.utils import dump_state_text

 random.seed(42)


 def gen_prompt(tokenizer, token_num):
     all_available_tokens = list(tokenizer.get_vocab().values())

@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
     return ret


+def get_cache_path(args):
+    # Create cache directory under ~/.cache/sglang
+    cache_dir = Path.home() / ".cache" / "sglang"
+
+    # Create a unique cache filename based on the arguments that affect generation
+    cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
+    return cache_dir / cache_key
+
+
 def gen_arguments(args, tokenizer):
-    multi_qas = [
-        {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
-        for _ in range(args.num_qa)
-    ]
-    for i in range(args.num_qa):
+    cache_path = get_cache_path(args)
+
+    # Try to load from cache first
+    if cache_path.exists():
+        print(f"Loading cached arguments from {cache_path}")
+        with open(cache_path, "r") as f:
+            return json.load(f)
+
+    print("Generating new arguments...")
+    # First progress bar for system prompts
+    multi_qas = []
+    for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
+        multi_qas.append(
+            {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
+        )
+
+    # Nested progress bars for QA pairs
+    for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
         qas = multi_qas[i]["qas"]
         for j in range(args.turns):
             qas.append(
@@ -38,14 +63,21 @@ def gen_arguments(args, tokenizer):
                 "new_tokens": args.len_a,
             }
         )

+    # Save to cache
+    cache_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(cache_path, "w") as f:
+        json.dump(multi_qas, f)
+    print(f"Cached arguments saved to {cache_path}")

     return multi_qas


 @sgl.function
 def multi_turns(s, system_prompt, qas):
     s += system_prompt

-    for qa in qas:
+    for i, qa in enumerate(qas):
         s += qa["prompt"]
         s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)

@@ -62,7 +94,7 @@ def main(args):
         multi_qas,
         temperature=0,
         backend=backend,
-        num_threads=args.parallel,
+        num_threads="auto",
         progress_bar=True,
     )
     latency = time.time() - tic
@@ -75,7 +107,6 @@
     value = {
         "task": "multi_turn_system_prompt_chat",
         "backend": args.backend,
-        "num_gpus": 1,
         "latency": round(latency, 3),
         "num_requests": args.num_qa,
         "num_turns": args.turns,
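For concreteness, here is how the cache key from get_cache_path above resolves for one argument set. This is a standalone copy of the function from the diff; the argument values and tokenizer name are made up:

```python
from argparse import Namespace
from pathlib import Path

def get_cache_path(args):
    # Same key construction as in the benchmark above.
    cache_dir = Path.home() / ".cache" / "sglang"
    cache_key = (
        f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}"
        f"_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
    )
    return cache_dir / cache_key

args = Namespace(num_qa=20, turns=4, system_prompt_len=8192, len_q=128,
                 len_a=256, tokenizer="meta-llama/Llama-3.1-8B-Instruct")
print(get_cache_path(args))
# e.g. /home/user/.cache/sglang/qa_20_4_8192_128_256_meta-llama_Llama-3.1-8B-Instruct.json
```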
4 changes: 2 additions & 2 deletions python/sglang/bench_serving.py
@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
     total_input_tokens = 0
     total_output_tokens = 0

-    for group_idx in range(num_groups):
+    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
         system_prompt = system_prompts[group_idx]
-        for prompt_idx in range(prompts_per_group):
+        for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
             question = questions[group_idx * prompts_per_group + prompt_idx]
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))
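The same nested-progress-bar pattern in isolation; leave=False is my own addition to keep finished inner bars from stacking (the diff does not set it):

```python
from tqdm import tqdm

num_groups, prompts_per_group = 4, 16
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
    for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions",
                           leave=False):
        pass  # build and tokenize one (system prompt, question) pair here
```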
10 changes: 7 additions & 3 deletions python/sglang/test/few_shot_gsm8k.py
@@ -48,9 +48,13 @@ def run_eval(args):
     # Select backend
     set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))

-    # Read data
-    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
-    filename = download_and_cache_file(url)
+    if args.data_path is None:
+        # Read data
+        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+        filename = download_and_cache_file(url)
+    else:
+        filename = args.data_path

     lines = list(read_jsonl(filename))

     # Construct prompts
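A standalone sketch of the fallback added above: prefer a user-supplied JSONL path, and only download the public GSM8K test split when none is given. The helpers here are simplified stand-ins for sglang's download_and_cache_file and read_jsonl, not the library code:

```python
import json
import os
import urllib.request

GSM8K_URL = (
    "https://raw.githubusercontent.com/openai/grade-school-math/"
    "master/grade_school_math/data/test.jsonl"
)

def resolve_dataset(data_path=None, cache_dir="/tmp"):
    if data_path is not None:
        return data_path  # a local file always wins over the download
    cached = os.path.join(cache_dir, "gsm8k_test.jsonl")
    if not os.path.exists(cached):  # download once, then reuse the cached copy
        urllib.request.urlretrieve(GSM8K_URL, cached)
    return cached

def read_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
```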
24 changes: 23 additions & 1 deletion rust/Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions rust/Cargo.toml
@@ -24,6 +24,8 @@
 futures-util = "0.3"
 serde_json = "1.0"
 pyo3 = { version = "0.22.5", features = ["extension-module"] }
 tokenizers = { version = "0.20.3", features = ["http"] }
+dashmap = "6.1.0"
+http = "1.1.0"

 [profile.release]
 lto = "thin"
3 changes: 3 additions & 0 deletions rust/README.md
@@ -46,6 +46,9 @@
 pip install <path-to-wheel>

 #### Option B: Development Mode

 For development purposes, you can install the package in editable mode:
+
+Warning: the editable Python binding can suffer from performance degradation! Build a fresh wheel after every update if you want to test performance.

 ```bash
 pip install -e .
 ```
10 changes: 0 additions & 10 deletions rust/demo.py

This file was deleted.

156 changes: 0 additions & 156 deletions rust/dp_demo.py

This file was deleted.

(Diffs for the remaining changed files, including the Rust router sources, are not rendered here.)
