From 8f6e3084503079d4fb145fdfd0b90795d3d142bf Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 02:26:22 +0000 Subject: [PATCH 01/10] fix --- python/sglang/srt/managers/scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e9bf7be8ee8..5b12479514a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -568,9 +568,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: has_inflight = self.current_inflight_req is not None if has_inflight: - self.current_inflight_req.init_next_round_input( - None if prefix_computed else self.tree_cache - ) + self.current_inflight_req.init_next_round_input() self.current_inflight_req = adder.add_inflight_req( self.current_inflight_req ) From d75cd3b87625097d61fdaafccba0339ed68b3381 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 02:38:44 +0000 Subject: [PATCH 02/10] fix fsm_cache when skip tokenizer init --- python/sglang/srt/managers/scheduler.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5b12479514a..961efe1d544 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -225,16 +225,15 @@ def __init__( ) # Init the FSM cache for constrained generation - if not server_args.skip_tokenizer_init: - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, - ) + self.regex_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, + ) self.jump_forward_cache = JumpForwardCache() # Init new token estimation From 994c08ca50f45eb4dbaee47968a887032f5985d0 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 03:00:13 +0000 Subject: [PATCH 03/10] fix new token ratio inconsistent --- python/sglang/global_config.py | 12 +++++++++++- python/sglang/srt/managers/scheduler.py | 16 ++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 5e7290edc51..37938d31578 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -15,7 +15,7 @@ def __init__(self): # Runtime constants: New generation token ratio estimation self.init_new_token_ratio = 0.7 - self.base_min_new_token_ratio = 0.1 + self.min_new_token_ratio = 0.1 self.new_token_ratio_decay = 0.001 # Runtime constants: others @@ -32,5 +32,15 @@ def __init__(self): self.enable_precache_with_tracing = True self.enable_parallel_encoding = True + def adjust_new_token_ratio(self, schedule_conservativeness=1): + assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness" + global_config.min_new_token_ratio = min( + global_config.min_new_token_ratio * schedule_conservativeness, + 1.0, + ) + global_config.init_new_token_ratio = max( + global_config.init_new_token_ratio, global_config.min_new_token_ratio + ) + global_config = GlobalConfig() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 961efe1d544..990e5fc8053 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -237,16 +237,8 @@ def __init__( self.jump_forward_cache = JumpForwardCache() # Init new token estimation - assert ( - server_args.schedule_conservativeness >= 0 - ), "Invalid schedule_conservativeness" - self.min_new_token_ratio = min( - global_config.base_min_new_token_ratio - * server_args.schedule_conservativeness, - 1.0, - ) - self.new_token_ratio = self.min_new_token_ratio - self.new_token_ratio_decay = global_config.new_token_ratio_decay + global_config.adjust_new_token_ratio(server_args.schedule_conservativeness) + self.new_token_ratio = global_config.init_new_token_ratio self.batch_is_full = False # Init profiler @@ -706,8 +698,8 @@ def update_running_batch(self): self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( - self.new_token_ratio - self.new_token_ratio_decay, - self.min_new_token_ratio, + self.new_token_ratio - global_config.new_token_ratio_decay, + global_config.min_new_token_ratio, ) # Check for jump-forward From 1379c594d030ce38321e45f0184341da83ed9b50 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 05:33:53 +0000 Subject: [PATCH 04/10] inflight -> being_chunked --- python/sglang/srt/managers/schedule_batch.py | 7 +-- python/sglang/srt/managers/schedule_policy.py | 25 ++++++--- python/sglang/srt/managers/scheduler.py | 55 ++++++++----------- 3 files changed, 43 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index fcd06d8cc9c..39fc1e558f1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -222,7 +222,7 @@ def __init__( self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None - self.is_inflight_req = 0 + self.is_being_chunked = False # Logprobs (arguments) self.return_logprob = False @@ -906,15 +906,14 @@ def prepare_for_decode(self, enable_overlap: bool = False): def filter_batch( self, - current_inflight_req: Optional[Req] = None, + being_chunked_req: Optional[Req] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: keep_indices = [ i for i in range(len(self.reqs)) - if not self.reqs[i].finished() - and self.reqs[i] is not current_inflight_req + if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 45c9be37a6e..a5362ff7cad 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -136,7 +136,7 @@ def __init__( self.req_states = None self.can_run_list = [] - self.new_inflight_req = None + self.new_chunked_req = None self.log_hit_tokens = 0 self.log_input_tokens = 0 @@ -176,7 +176,7 @@ def _prefill_one_req( self.log_hit_tokens += prefix_len self.log_input_tokens += extend_input_len - def add_inflight_req(self, req: Req): + def add_being_chunked_req(self, req: Req): truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] @@ -192,8 +192,13 @@ def add_inflight_req(self, req: Req): ), ) - # Return if chunked prefill not finished - return req if truncated else None + if truncated: + # Continue to chunk the request + assert req.is_being_chunked + self.new_chunked_req = req + else: + # Release the being chunked status + req.is_being_chunked = False @contextmanager def _lock_node(self, last_node: TreeNode): @@ -262,11 +267,14 @@ def add_req_state(r, insert_sort=False): ) else: # Chunked prefill + assert self.new_chunked_req is None + trunc_len = self.rem_chunk_tokens req.extend_input_len = trunc_len + req.is_being_chunked = True req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) - self.new_inflight_req = req + self.new_chunked_req = req self._prefill_one_req(0, trunc_len, 0) return self.budget_state() @@ -305,15 +313,18 @@ def add_one_req(self, req: Req): min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), ) else: - # Chunked prefill trunc_len = self.rem_chunk_tokens if trunc_len == 0: return AddReqResult.OTHER + # Chunked prefill + assert self.new_chunked_req is None + req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] + req.is_being_chunked = True self.can_run_list.append(req) - self.new_inflight_req = req + self.new_chunked_req = req self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 990e5fc8053..fe6ac67eba5 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -219,7 +219,7 @@ def __init__( # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size - self.current_inflight_req = None + self.being_chunked_req = None self.is_mixed_chunk = ( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) @@ -490,20 +490,18 @@ def check_memory(self): ) exit(1) if crash_on_warning else None - def get_next_batch_to_run(self): + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch if ( self.last_batch and not self.last_batch.forward_mode.is_decode() and not self.last_batch.is_empty() ): - if self.current_inflight_req: - self.last_batch.filter_batch( - current_inflight_req=self.current_inflight_req - ) - self.tree_cache.cache_unfinished_req(self.current_inflight_req) - # Inflight request keeps its rid but will get a new req_pool_idx. - self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx) + if self.being_chunked_req: + self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) + self.tree_cache.cache_unfinished_req(self.being_chunked_req) + # Being chunked request keeps its rid but will get a new req_pool_idx. + self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) self.batch_is_full = False if not self.last_batch.is_empty(): if self.running_batch is None: @@ -534,7 +532,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Handle the cases where prefill is not allowed if ( self.batch_is_full or len(self.waiting_queue) == 0 - ) and self.current_inflight_req is None: + ) and self.being_chunked_req is None: return None running_bs = len(self.running_batch.reqs) if self.running_batch else 0 @@ -557,13 +555,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: num_mixed_running, ) - has_inflight = self.current_inflight_req is not None - if has_inflight: - self.current_inflight_req.init_next_round_input() - self.current_inflight_req = adder.add_inflight_req( - self.current_inflight_req - ) - if self.lora_paths: lora_set = ( set([req.lora_path for req in self.running_batch.reqs]) @@ -571,6 +562,12 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: else set([]) ) + # NOTE: if there is request being chunked, we always add it first + has_being_chunked = self.being_chunked_req is not None + if has_being_chunked: + self.being_chunked_req.init_next_round_input() + adder.add_being_chunked_req(self.being_chunked_req) + # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: if ( @@ -604,12 +601,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: x for x in self.waiting_queue if x not in set(can_run_list) ] - if adder.new_inflight_req is not None: - assert self.current_inflight_req is None - self.current_inflight_req = adder.new_inflight_req - - if self.current_inflight_req: - self.current_inflight_req.is_inflight_req += 1 + # Update new round being chunked request + self.being_chunked_req = adder.new_chunked_req # Print stats if self.tp_rank == 0: @@ -638,7 +631,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: f"#cached-token: {adder.log_hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"#queue-req: {len(self.waiting_queue) + has_inflight}" + f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" ) else: logger.info( @@ -649,7 +642,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) + has_inflight}" + f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" ) # Create a new batch @@ -772,10 +765,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # Check finish conditions logprob_pt = 0 for i, req in enumerate(batch.reqs): - if req.is_inflight_req > 0: - req.is_inflight_req -= 1 - else: - # Inflight reqs' prefill is not finished + if not req.is_being_chunked: + # Being chunked reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_ids[i]) req.check_finished() @@ -801,10 +792,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # Check finish conditions for i, req in enumerate(batch.reqs): req.embedding = embeddings[i] - if req.is_inflight_req > 0: - req.is_inflight_req -= 1 - else: - # Inflight reqs' prefill is not finished + if not req.is_being_chunked: + # Being chunked reqs' prefill is not finished # dummy output token for embedding models req.output_ids.append(0) req.check_finished() From 86d976a045cc8f8576fcff1d1566860229f78897 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 05:45:49 +0000 Subject: [PATCH 05/10] resolve comment --- python/sglang/global_config.py | 8 +++++--- python/sglang/srt/managers/scheduler.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 37938d31578..60c11a8098b 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -34,13 +34,15 @@ def __init__(self): def adjust_new_token_ratio(self, schedule_conservativeness=1): assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness" - global_config.min_new_token_ratio = min( + min_new_token_ratio = min( global_config.min_new_token_ratio * schedule_conservativeness, 1.0, ) - global_config.init_new_token_ratio = max( - global_config.init_new_token_ratio, global_config.min_new_token_ratio + init_new_token_ratio = max( + global_config.init_new_token_ratio, min_new_token_ratio ) + return min_new_token_ratio, init_new_token_ratio + global_config = GlobalConfig() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fe6ac67eba5..76e3be07338 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -237,8 +237,10 @@ def __init__( self.jump_forward_cache = JumpForwardCache() # Init new token estimation - global_config.adjust_new_token_ratio(server_args.schedule_conservativeness) - self.new_token_ratio = global_config.init_new_token_ratio + self.min_new_token_ratio, self.init_new_token_ratio = ( + global_config.adjust_new_token_ratio(server_args.schedule_conservativeness) + ) + self.new_token_ratio = self.init_new_token_ratio self.batch_is_full = False # Init profiler @@ -285,7 +287,7 @@ def event_loop_normal(self): self.process_batch_result(batch, result) else: self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio + self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch @@ -312,7 +314,7 @@ def event_loop_overlap(self): self.process_batch_result(tmp_batch, tmp_result) elif batch is None: self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio + self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch @@ -565,6 +567,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # NOTE: if there is request being chunked, we always add it first has_being_chunked = self.being_chunked_req is not None if has_being_chunked: + # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result self.being_chunked_req.init_next_round_input() adder.add_being_chunked_req(self.being_chunked_req) @@ -692,7 +695,7 @@ def update_running_batch(self): else: self.new_token_ratio = max( self.new_token_ratio - global_config.new_token_ratio_decay, - global_config.min_new_token_ratio, + self.min_new_token_ratio, ) # Check for jump-forward From 2173783609fbc05ded036363680ff475fd7b9293 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 06:41:08 +0000 Subject: [PATCH 06/10] add unit test --- test/srt/run_suite.py | 1 + test/srt/test_radix_attention.py | 112 +++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 test/srt/test_radix_attention.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3f8a1fecb1f..2b1be4ed76c 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -5,6 +5,7 @@ suites = { "minimal": [ + "test_radix_attention.py", "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py new file mode 100644 index 00000000000..292a7b454d9 --- /dev/null +++ b/test/srt/test_radix_attention.py @@ -0,0 +1,112 @@ +import os +import random +import unittest + +import requests + +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + kill_child_process, + popen_launch_server, +) + + +def gen_radix_tree(num_nodes=400, chunk_len=256): + num0 = num_nodes // 2 + num1 = num_nodes - num0 + nodes = [{"input_ids": [37] * 117, "decode_len": 217}] + for _ in range(num0): + parent = random.choice(nodes) + unique_len = random.randint(0, chunk_len) + decode_len = random.randint(0, chunk_len) + token_id = random.randint(0, 32000) + child = { + "input_ids": parent["input_ids"] + [token_id] * unique_len, + "decode_len": decode_len, + } + nodes.append(child) + + while num1 > 0: + num_branch = random.randint(1, min(num1, 10)) + parent = random.choice(nodes) + for _ in range(num_branch): + unique_len = random.randint(0, chunk_len) + decode_len = random.randint(0, chunk_len) + token_id = random.randint(0, 32000) + child = { + "input_ids": parent["input_ids"] + [token_id] * unique_len, + "decode_len": decode_len, + } + nodes.append(child) + + num1 -= num_branch + + random.shuffle(nodes) + return nodes + + +def run_test(base_url, nodes): + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes + ], + } + + res = requests.post(base_url + "/generate", json=data) + assert res.status_code == 200 + + +class TestRadixCacheFCFS(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "fcfs", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_radix_attention(self): + nodes = gen_radix_tree() + run_test(self.base_url, nodes) + + +class TestRadixCacheLPM(TestRadixCacheFCFS): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "lpm", + ], + ) + + +if __name__ == "__main__": + os.environ["SGLANG_TEST_RETRACT"] = "true" + unittest.main() From 3272bfccc8d2cfc886736f4d1116fc166058e781 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 25 Oct 2024 07:33:03 +0000 Subject: [PATCH 07/10] fix mem usage --- python/sglang/test/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 20fc9d52da9..baea2fa5208 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -658,6 +658,7 @@ def run_mmlu_test( chunked_prefill_size=32, ): other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] + other_args += ["--mem-fraction-static", "0.85"] if disable_radix_cache: other_args += ["--disable-radix-cache"] if enable_mixed_chunk: From b192faa046f3fa490e90bed6db0226051d2c6704 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 30 Oct 2024 04:59:26 +0000 Subject: [PATCH 08/10] update --- python/sglang/global_config.py | 13 ------------- python/sglang/srt/managers/scheduler.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 4c3ec485840..5b563ebf96a 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -38,17 +38,4 @@ def __init__(self): self.enable_precache_with_tracing = True self.enable_parallel_encoding = True - def adjust_new_token_ratio(self, schedule_conservativeness=1): - assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness" - min_new_token_ratio = min( - global_config.min_new_token_ratio * schedule_conservativeness, - 1.0, - ) - init_new_token_ratio = max( - global_config.init_new_token_ratio, min_new_token_ratio - ) - - return min_new_token_ratio, init_new_token_ratio - - global_config = GlobalConfig() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 642dbcfcf31..80a65ca0e42 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -746,7 +746,7 @@ def update_running_batch(self): self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( - self.new_token_ratio - global_config.new_token_ratio_decay, + self.new_token_ratio - self.new_token_ratio_decay, self.min_new_token_ratio, ) From 5d71528965e94247b798994a03675cd4c66e69c0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 31 Oct 2024 13:45:58 -0700 Subject: [PATCH 09/10] Make is_being_chunked a counter --- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/managers/schedule_policy.py | 6 +++--- python/sglang/test/test_utils.py | 1 - test/srt/run_suite.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 750408d56e5..f99f3377e52 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -221,7 +221,7 @@ def __init__( self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None - self.is_being_chunked = False + self.is_being_chunked = 0 # Logprobs (arguments) self.return_logprob = False diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index b705fb1b7ba..43889289582 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -200,7 +200,7 @@ def add_being_chunked_req(self, req: Req): self.new_chunked_req = req else: # Release the being chunked status - req.is_being_chunked = False + req.is_being_chunked -= 1 @contextmanager def _lock_node(self, last_node: TreeNode): @@ -273,7 +273,7 @@ def add_req_state(r, insert_sort=False): trunc_len = self.rem_chunk_tokens req.extend_input_len = trunc_len - req.is_being_chunked = True + req.is_being_chunked += 1 req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) self.new_chunked_req = req @@ -327,7 +327,7 @@ def add_one_req(self, req: Req): req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] - req.is_being_chunked = True + req.is_being_chunked += 1 self.can_run_list.append(req) self.new_chunked_req = req self.tree_cache.inc_lock_ref(req.last_node) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index e7ca7748772..8a486131f0c 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -665,7 +665,6 @@ def run_and_check_memory_leak( chunked_prefill_size, ): other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] - other_args += ["--mem-fraction-static", "0.85"] if disable_radix_cache: other_args += ["--disable-radix-cache"] if enable_mixed_chunk: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 6074da11e00..f7277f03dab 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -5,7 +5,6 @@ suites = { "minimal": [ - "test_radix_attention.py", "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", @@ -20,6 +19,7 @@ "test_openai_server.py", "test_overlap_schedule.py", "test_pytorch_sampling_backend.py", + "test_radix_attention.py", "test_retract_decode.py", "test_server_args.py", "test_skip_tokenizer_init.py", From 7f86f05866aedbf51ed64874d92445b5aba4c810 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 31 Oct 2024 13:46:36 -0700 Subject: [PATCH 10/10] Fix style --- python/sglang/global_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 5b563ebf96a..d557e6a6e62 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -38,4 +38,5 @@ def __init__(self): self.enable_precache_with_tracing = True self.enable_parallel_encoding = True + global_config = GlobalConfig()