Commit

Add a new event loop (#1677)
merrymercy authored Oct 16, 2024
1 parent a5114b6 commit 9116b28
Showing 9 changed files with 161 additions and 25 deletions.
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -736,6 +736,10 @@ def prepare_for_decode(self):
self.input_ids = self.output_ids
self.seq_lens.add_(1)
self.output_ids = None
if self.sampling_info.penalizer_orchestrator:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.input_ids
)

# Alloc mem
bs = len(self.reqs)
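For context on the relocated call: below is a toy stand-in (not sglang's penalizer orchestrator) showing what cumulate_output_tokens conceptually does here, namely recording the tokens sampled at the previous step, which prepare_for_decode has just copied into self.input_ids, so that frequency/repetition penalties can see them on the next pass.

from collections import Counter

class ToyFrequencyPenalizer:
    # Illustrative stand-in: track how often each request has emitted each token
    # so a frequency penalty can be applied to later logits.
    def __init__(self, batch_size: int):
        self.output_counts = [Counter() for _ in range(batch_size)]

    def cumulate_output_tokens(self, last_token_ids):
        # One newly sampled token id per request, i.e. the ids that
        # prepare_for_decode copies into self.input_ids above.
        for counts, token_id in zip(self.output_counts, last_token_ids):
            counts[int(token_id)] += 1

    def apply(self, logits_row, req_index: int, frequency_penalty: float):
        # Penalize every token this request has already generated.
        for token_id, count in self.output_counts[req_index].items():
            logits_row[token_id] -= frequency_penalty * count
        return logits_row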
63 changes: 49 additions & 14 deletions python/sglang/srt/managers/scheduler.py
@@ -20,6 +20,7 @@
import os
import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional, Union

@@ -192,9 +193,20 @@ def __init__(
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

if self.server_args.enable_overlap_schedule:

def cache_finished_req(req):
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
self.tree_cache.cache_finished_req(req, free_delta=free_delta)

else:
cache_finished_req = self.tree_cache.cache_finished_req
self.cache_finished_req = cache_finished_req

# Init running status
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.num_generated_tokens = 0
@@ -279,6 +291,32 @@ def event_loop_normal(self):

self.last_batch = batch

@torch.inference_mode()
def event_loop_overlap(self):
result_queue = deque()

self.last_batch = None
self.running_batch = None

while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)

batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))

if self.last_batch:
tmp_batch, tmp_result = result_queue.popleft()
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio

self.last_batch = batch

def recv_requests(self):
if self.tp_rank == 0:
recv_reqs = []
@@ -705,11 +743,6 @@ def process_batch_result(self, batch: ScheduleBatch, result):
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)

if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
req.check_finished()

if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req)

@@ -771,18 +804,14 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
req.check_finished()

if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)

self.stream_output(batch.reqs)

def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
self.num_generated_tokens += len(batch.reqs)

# Move logprobs to cpu
@@ -796,6 +825,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if self.server_args.enable_overlap_schedule and req.finished():
continue

req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -806,7 +838,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
)

if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)

if req.return_logprob:
req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@ def abort_request(self, recv_req: AbortReq):
for req in self.running_batch.reqs:
if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT()
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
break

def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
pipe_writer.send("ready")
scheduler.event_loop_normal()
if server_args.enable_overlap_schedule:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
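A minimal, self-contained sketch of the pattern event_loop_overlap introduces above; the callables are toy stand-ins, not sglang APIs. Results are queued and consumed one iteration late, so CPU-side result processing for step N overlaps with the forward pass launched for step N+1.

from collections import deque

def overlap_loop_sketch(get_next_batch, run_batch, process_batch_result, num_steps=8):
    # Mirrors event_loop_overlap: the result of step N is only consumed at step
    # N+1, after the next batch has already been launched.
    result_queue = deque()
    last_batch = None
    for _ in range(num_steps):
        batch = get_next_batch()  # may return None when nothing is runnable
        if batch is not None:
            result_queue.append((batch, run_batch(batch)))
        if last_batch is not None:
            prev_batch, prev_result = result_queue.popleft()
            process_batch_result(prev_batch, prev_result)
        last_batch = batch

With a blocking run_batch this is just the normal loop shifted by one step; the benefit appears when run_batch only enqueues GPU work and returns immediately.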
16 changes: 11 additions & 5 deletions python/sglang/srt/mem_cache/chunk_cache.py
@@ -38,12 +38,16 @@ def match_prefix(self, rid: int, key: List[int]):
max_prefix_len = len(key)
return entry.value[:max_prefix_len], entry

def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_finished_req(
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
):
if token_ids is None:
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_id_len = len(token_ids)

kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, : token_id_len + free_delta
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +57,12 @@ def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):

def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
token_ids = req.fill_ids
token_id_len = len(req.fill_ids)
else:
token_id_len = len(token_ids)

kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :token_id_len
]

if req.rid not in self.entries:
28 changes: 22 additions & 6 deletions python/sglang/srt/mem_cache/radix_cache.py
@@ -97,22 +97,38 @@ def insert(self, key: List, value=None):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)

def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_finished_req(
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
):
"""Cache request when it finishes."""
if self.disable:
if token_ids is None:
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_ids_len = len(token_ids)

kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : token_ids_len + free_delta
]
self.token_to_kv_pool.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return

if token_ids is None:
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]

if self.disable:
self.token_to_kv_pool.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return

# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
if free_delta:
self.token_to_kv_pool.free(
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(token_ids) : len(token_ids) + 1
]
)

# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
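Why free_delta exists: with the overlap schedule a request can finish while the next decode step is already in flight, so one extra KV slot may have been allocated past the usual token range. A small sketch of the arithmetic for the chunk-cache / disabled-radix path, assuming slots mirror req_to_token[req_pool_idx, :n]:

def kv_slots_to_free(num_prompt_tokens: int, num_output_tokens: int, free_delta: int) -> int:
    # A finished request releases the KV entries for all tokens except the last
    # sampled one (hence the -1), plus free_delta extra slot(s) that the overlap
    # loop may already have reserved for the in-flight decode step.
    token_id_len = num_prompt_tokens + num_output_tokens - 1
    return token_id_len + free_delta

# Example: 8 prompt tokens, 3 generated tokens, caught mid-overlap (free_delta=1):
assert kv_slots_to_free(8, 3, free_delta=1) == 11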
2 changes: 2 additions & 0 deletions python/sglang/srt/server.py
@@ -528,6 +528,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
kill_child_process(pid, including_parent=False)
return

# print(f"{res.json()=}")

logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -113,6 +113,7 @@ class ServerArgs:
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_penalizer: bool = False
enable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
max_torch_compile_bs: int = 32
@@ -572,6 +573,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
)
parser.add_argument(
"--enable-overlap-schedule",
action="store_true",
help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
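How the flag is meant to be used, as a hedged sketch: apart from enable_overlap_schedule (added in this diff), the launch command and dispatcher below restate the usual sglang entry point and the run_scheduler_process change above rather than introducing anything new.

# CLI usage (assumed standard launch entry point plus the new flag):
#   python3 -m sglang.launch_server --model-path <your-model> --enable-overlap-schedule

def pick_event_loop(scheduler, server_args):
    # Toy dispatcher mirroring the change to run_scheduler_process above.
    if server_args.enable_overlap_schedule:
        return scheduler.event_loop_overlap
    return scheduler.event_loop_normal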
1 change: 1 addition & 0 deletions python/sglang/srt/utils.py
@@ -584,6 +584,7 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

def configure_logger(server_args, prefix: str = ""):
format = f"[%(asctime)s{prefix}] %(message)s"
# format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format=format,
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -17,6 +17,7 @@
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_openai_server.py",
"test_overlap_schedule.py",
"test_pytorch_sampling_backend.py",
"test_retract_decode.py",
"test_server_args.py",
65 changes: 65 additions & 0 deletions test/srt/test_overlap_schedule.py
@@ -0,0 +1,65 @@
"""
Usage:
SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
"""

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestOverlapSchedule(unittest.TestCase):
def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
other_args += ["--enable-overlap-schedule"]

model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)

try:
metrics = run_eval(args)
assert metrics["score"] >= 0.65
finally:
kill_child_process(process.pid)

def test_no_radix_attention_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)

def test_no_radix_attention_no_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)

def test_radix_attention_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)

def test_radix_attention_no_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)


if __name__ == "__main__":
unittest.main()
# @unittest.skip("did not support")
