Skip to content

Commit

Permalink
Fix mixed chunked prefill (#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 31, 2024
1 parent a7a0a68 commit f7102fb
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
8 changes: 5 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,11 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch.filter_batch()
if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
else:
new_batch.decoding_reqs = None
Expand Down
88 changes: 68 additions & 20 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional
Expand Down Expand Up @@ -656,11 +657,12 @@ def read_output(output_lines):
time.sleep(0.1)


def run_mmlu_test(
def run_and_check_memory_leak(
workload_func,
disable_radix_cache,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
enable_mixed_chunk,
enable_overlap,
chunked_prefill_size,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache:
Expand Down Expand Up @@ -690,21 +692,8 @@ def run_mmlu_test(
t = threading.Thread(target=read_output, args=(output_lines,))
t.start()

# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)

try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
# Run the workload
workload_func(base_url, model)

# Clean up everything
kill_child_process(process.pid, include_self=True)
Expand All @@ -727,4 +716,63 @@ def run_mmlu_test(
has_leak = True

assert has_new_server
# assert not has_leak
assert not has_leak


def run_mmlu_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)

try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass

run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)


def run_mulit_request_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):

def workload_func(base_url, model):
def run_one(_):
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""

response = requests.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
)
ret = response.json()

with ThreadPoolExecutor(2) as executor:
list(executor.map(run_one, list(range(4))))

run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
7 changes: 7 additions & 0 deletions test/srt/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DEFAULT_MODEL_NAME_FOR_TEST,
run_bench_serving,
run_mmlu_test,
run_mulit_request_test,
)


Expand Down Expand Up @@ -39,6 +40,12 @@ def test_no_chunked_prefill_without_radix_cache(self):

assert res["completed"] == 10

def test_mixed_chunked_prefill_multi_requests(self):
run_mulit_request_test(
enable_mixed_chunk=True,
chunked_prefill_size=2048,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit f7102fb

Please sign in to comment.