Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memroy leak caused by chunked prefill #1837

Closed
wants to merge 14 commits into from
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
self.is_inflight_req = 0
self.is_being_chunked = 0

# Logprobs (arguments)
self.return_logprob = False
Expand Down Expand Up @@ -888,15 +888,14 @@ def prepare_for_decode(self, enable_overlap: bool = False):

def filter_batch(
self,
current_inflight_req: Optional[Req] = None,
being_chunked_req: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] is not current_inflight_req
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
]

if keep_indices is None or len(keep_indices) == 0:
Expand Down
25 changes: 18 additions & 7 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(

self.req_states = None
self.can_run_list = []
self.new_inflight_req = None
self.new_chunked_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0

Expand Down Expand Up @@ -178,7 +178,7 @@ def _prefill_one_req(
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

def add_inflight_req(self, req: Req):
def add_being_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
Expand All @@ -194,8 +194,13 @@ def add_inflight_req(self, req: Req):
),
)

# Return if chunked prefill not finished
return req if truncated else None
if truncated:
# Continue to chunk the request
assert req.is_being_chunked
self.new_chunked_req = req
else:
# Release the being chunked status
req.is_being_chunked -= 1

@contextmanager
def _lock_node(self, last_node: TreeNode):
Expand Down Expand Up @@ -264,11 +269,14 @@ def add_req_state(r, insert_sort=False):
)
else:
# Chunked prefill
assert self.new_chunked_req is None

trunc_len = self.rem_chunk_tokens
req.extend_input_len = trunc_len
req.is_being_chunked += 1
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)

return self.budget_state()
Expand Down Expand Up @@ -310,15 +318,18 @@ def add_one_req(self, req: Req):
),
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER

# Chunked prefill
assert self.new_chunked_req is None

req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
req.is_being_chunked += 1
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)

Expand Down
58 changes: 23 additions & 35 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.being_chunked_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
Expand Down Expand Up @@ -544,20 +544,18 @@ def check_memory(self):
)
exit(1) if crash_on_warning else None

def get_next_batch_to_run(self):
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.current_inflight_req:
self.last_batch.filter_batch(
current_inflight_req=self.current_inflight_req
)
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
if self.being_chunked_req:
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Being chunked request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False
if not self.last_batch.is_empty():
if self.running_batch is None:
Expand Down Expand Up @@ -588,7 +586,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
) and self.being_chunked_req is None:
return None

running_bs = len(self.running_batch.reqs) if self.running_batch else 0
Expand All @@ -611,22 +609,20 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
num_mixed_running,
)

has_inflight = self.current_inflight_req is not None
if has_inflight:
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)

if self.lora_paths:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)

# NOTE: if there is request being chunked, we always add it first
has_being_chunked = self.being_chunked_req is not None
if has_being_chunked:
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
self.being_chunked_req.init_next_round_input()
adder.add_being_chunked_req(self.being_chunked_req)

# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
Expand Down Expand Up @@ -660,12 +656,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
x for x in self.waiting_queue if x not in set(can_run_list)
]

if adder.new_inflight_req is not None:
assert self.current_inflight_req is None
self.current_inflight_req = adder.new_inflight_req

if self.current_inflight_req:
self.current_inflight_req.is_inflight_req += 1
# Update new round being chunked request
self.being_chunked_req = adder.new_chunked_req

# Print stats
if self.tp_rank == 0:
Expand Down Expand Up @@ -694,7 +686,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)
else:
logger.info(
Expand All @@ -705,7 +697,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)

# Create a new batch
Expand Down Expand Up @@ -833,10 +825,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
logprob_pt = 0
for i, req in enumerate(batch.reqs):
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
Expand All @@ -860,10 +850,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
# dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"test_openai_server.py",
"test_overlap_schedule.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
Expand Down
112 changes: 112 additions & 0 deletions test/srt/test_radix_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import random
import unittest

import requests

from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
kill_child_process,
popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
num0 = num_nodes // 2
num1 = num_nodes - num0
nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
for _ in range(num0):
parent = random.choice(nodes)
unique_len = random.randint(0, chunk_len)
decode_len = random.randint(0, chunk_len)
token_id = random.randint(0, 32000)
child = {
"input_ids": parent["input_ids"] + [token_id] * unique_len,
"decode_len": decode_len,
}
nodes.append(child)

while num1 > 0:
num_branch = random.randint(1, min(num1, 10))
parent = random.choice(nodes)
for _ in range(num_branch):
unique_len = random.randint(0, chunk_len)
decode_len = random.randint(0, chunk_len)
token_id = random.randint(0, 32000)
child = {
"input_ids": parent["input_ids"] + [token_id] * unique_len,
"decode_len": decode_len,
}
nodes.append(child)

num1 -= num_branch

random.shuffle(nodes)
return nodes


def run_test(base_url, nodes):
data = {
"input_ids": [node["input_ids"] for node in nodes],
"sampling_params": [
{"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
],
}

res = requests.post(base_url + "/generate", json=data)
assert res.status_code == 200


class TestRadixCacheFCFS(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
"128",
"--max-total-tokens",
"20000",
"--schedule-policy",
"fcfs",
],
)

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)

def test_radix_attention(self):
nodes = gen_radix_tree()
run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
"128",
"--max-total-tokens",
"20000",
"--schedule-policy",
"lpm",
],
)


if __name__ == "__main__":
os.environ["SGLANG_TEST_RETRACT"] = "true"
unittest.main()
Loading