diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8d7ccd3547e..8684fa5ce0d 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -736,6 +736,10 @@ def prepare_for_decode(self):
         self.input_ids = self.output_ids
         self.seq_lens.add_(1)
         self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e3fea477788..6bc93ea40d3 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -20,6 +20,7 @@
 import os
 import time
 import warnings
+from collections import deque
 from types import SimpleNamespace
 from typing import List, Optional, Union
 
@@ -192,9 +193,20 @@ def __init__(
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
 
+        if self.server_args.enable_overlap_schedule:
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+        else:
+            cache_finished_req = self.tree_cache.cache_finished_req
+        self.cache_finished_req = cache_finished_req
+
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -279,6 +291,32 @@ def event_loop_normal(self):
 
             self.last_batch = batch
 
+    @torch.inference_mode()
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
@@ -705,11 +743,6 @@ def process_batch_result(self, batch: ScheduleBatch, result):
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            if batch.sampling_info.penalizer_orchestrator:
-                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                    next_token_ids
-                )
-
             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 req.check_finished()
 
                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -771,7 +804,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 req.check_finished()
 
                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 else:
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -779,10 +812,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        if batch.sampling_info.penalizer_orchestrator:
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
@@ -796,6 +825,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -806,7 +838,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
             )
 
             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                self.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@ def abort_request(self, recv_req: AbortReq):
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop_normal()
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
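Reviewer note: the heart of this change is `event_loop_overlap` above. The forward pass for batch N is launched first, and only then is the result of batch N-1 processed, so the CPU-side bookkeeping runs while the GPU (whose kernel launches are asynchronous) is still busy with batch N; a `deque` holds the single in-flight `(batch, result)` pair. This one-step delay also appears to be why `cumulate_output_tokens` moves out of `process_batch_result_*` and into `ScheduleBatch.prepare_for_decode`: the penalizer state is updated from `input_ids` when the next decode step is formed rather than when results are (later) processed. Below is a minimal, CPU-only sketch of the deferred-by-one pattern; `run_batch` and `process_batch_result` are simplified stand-ins for the scheduler methods of the same names, and the single-worker thread pool stands in for the asynchronous GPU stream:

```python
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def run_batch(batch):
    time.sleep(0.01)  # pretend this is the GPU forward pass
    return [f"<next token for {req}>" for req in batch]


def process_batch_result(batch, result):
    print(f"processed: {list(zip(batch, result))}")


executor = ThreadPoolExecutor(max_workers=1)  # plays the asynchronous GPU stream
result_queue = deque()
last_batch = None

for batch in (["req0"], ["req0", "req1"], ["req0", "req1"]):
    # Launch batch N without waiting for it to finish ...
    result_queue.append((batch, executor.submit(run_batch, batch)))

    # ... and, while it runs, consume the result of batch N-1.
    if last_batch is not None:
        prev_batch, prev_future = result_queue.popleft()
        process_batch_result(prev_batch, prev_future.result())
    last_batch = batch

# One batch is still in flight when the loop exits; drain it.
prev_batch, prev_future = result_queue.popleft()
process_batch_result(prev_batch, prev_future.result())
executor.shutdown()
```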
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index e7e48ecee49..e13a2075ab9 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -38,12 +38,16 @@ def match_prefix(self, rid: int, key: List[int]):
             max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry
 
-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, : token_id_len + free_delta
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +57,12 @@ def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
 
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = req.fill_ids
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]
 
         if req.rid not in self.entries:
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 7690d18b791..68bcb5b0b45 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -97,22 +97,38 @@ def insert(self, key: List, value=None):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : token_ids_len + free_delta
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
 
-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        if free_delta:
+            self.token_to_kv_pool.free(
+                self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
+                ]
+            )
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
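Reviewer note on `free_delta`: under overlap scheduling a request's finished status is observed one step late, after `prepare_for_decode` has already reserved a `req_to_token` slot for a next token that will never be kept. The closure installed in `Scheduler.__init__` therefore passes `free_delta=1` exactly when the finished request is also part of the batch currently in flight, and both cache implementations widen the freed range by that amount. A worked example of the index arithmetic, with made-up sizes:

```python
# Hypothetical request: 5 prompt tokens, finished after its 3rd output token.
prompt_len, output_len = 5, 3

# KV entries the request logically owns. The last output token was never
# fed back through the model, so it has no KV entry; hence the -1,
# matching (origin_input_ids + output_ids)[:-1] in the code above.
token_id_len = prompt_len + output_len - 1  # 7 entries: slots 0..6

# Overlap scheduling already allocated one more slot for the in-flight
# decode step, so that slot must be freed too.
free_delta = 1
freed_slots = list(range(token_id_len + free_delta))  # slots 0..7
assert len(freed_slots) == 8
```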
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7111c9333d2..e398ab4b091 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -528,6 +528,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    # print(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3184b4e7d94..eafdede82cc 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -113,6 +113,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -572,6 +573,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
         action="store_true",
         help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
     )
+    parser.add_argument(
+        "--enable-overlap-schedule",
+        action="store_true",
+        help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+    )
     parser.add_argument(
         "--enable-mixed-chunk",
         action="store_true",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ac2a8cf7f87..c5aecb73938 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -584,6 +584,7 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
 
 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 10775e1004a..3f8a1fecb1f 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -17,6 +17,7 @@
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_retract_decode.py",
         "test_server_args.py",
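Reviewer note: the flag is plumbed through `ServerArgs`, so any launch path picks it up. Below is a quick smoke test using the same helpers the new test file relies on; this is a sketch that assumes a GPU and the default test model are available and that the server runs from a checkout with this patch applied:

```python
import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# Launch a server with the experimental overlap schedule enabled.
process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--enable-overlap-schedule"],
)
try:
    # One generation request against the server's /generate endpoint.
    resp = requests.post(
        f"{DEFAULT_URL_FOR_TEST}/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 8},
        },
    )
    print(resp.json())
finally:
    kill_child_process(process.pid)
```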
diff --git a/test/srt/test_overlap_schedule.py b/test/srt/test_overlap_schedule.py
new file mode 100644
index 00000000000..76d17aff2a3
--- /dev/null
+++ b/test/srt/test_overlap_schedule.py
@@ -0,0 +1,65 @@
+"""
+Usage:
+SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
+SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
+"""
+
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestOverlapSchedule(unittest.TestCase):
+    def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
+        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+        if disable_radix_cache:
+            other_args += ["--disable-radix-cache"]
+        other_args += ["--enable-overlap-schedule"]
+
+        model = DEFAULT_MODEL_NAME_FOR_TEST
+        base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+        args = SimpleNamespace(
+            base_url=base_url,
+            model=model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        try:
+            metrics = run_eval(args)
+            assert metrics["score"] >= 0.65
+        finally:
+            kill_child_process(process.pid)
+
+    def test_no_radix_attention_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)
+
+    def test_no_radix_attention_no_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)
+
+    def test_radix_attention_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)
+
+    def test_radix_attention_no_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)
+
+
+if __name__ == "__main__":
+    unittest.main()
+    # @unittest.skip("did not support")
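The new tests cover the 2x2 matrix of radix cache on/off and chunked prefill on/off, each asserting an MMLU score of at least 0.65 with the overlap schedule enabled. Besides the invocations in the module docstring, the class can also be driven through stock `unittest` APIs; a sketch assuming it is run from `test/srt/` with the same GPU and model requirements:

```python
import unittest

# Load only the new overlap-schedule tests and run them verbosely,
# mirroring the docstring's `python3 -m unittest ...` usage.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_overlap_schedule.TestOverlapSchedule"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```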