Fix memory leak for chunked prefill 2 #1858

Merged (6 commits) on Oct 31, 2024
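
A summary of the fix, distilled from the diffs below: the chunked-prefill bookkeeping is renamed from `current_inflight_req`/`is_inflight_req` to `being_chunked_req`/`is_being_chunked`, and the scheduler now frees a chunked request's `req_to_token_pool` slot between chunks instead of leaking it. A minimal sketch of the `is_being_chunked` counter lifecycle (the `Req` class and helpers here are simplified stand-ins, not the real sglang types):

```python
# Sketch of the counter this PR maintains; Req and the helpers are
# illustrative stand-ins for the real scheduler logic.
class Req:
    def __init__(self, rid):
        self.rid = rid
        self.is_being_chunked = 0  # a counter, not a bool


def schedule_chunk(req):
    # Incremented each time the request is put into a prefill batch
    # as the request currently being chunked.
    req.is_being_chunked += 1


def process_prefill_result(req):
    # Mirrors process_batch_result_prefill: a positive counter means this
    # result was for an intermediate chunk, so prefill is not finished yet.
    if req.is_being_chunked > 0:
        req.is_being_chunked -= 1
        return "chunk done; prefill continues"
    return "prefill finished; request can start decoding"


req = Req("example")
schedule_chunk(req)
print(process_prefill_result(req))  # chunk done; prefill continues
print(process_prefill_result(req))  # prefill finished; request can start decoding
```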
6 changes: 3 additions & 3 deletions .github/workflows/pr-test.yml
@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
-python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+python3 run_suite.py --suite minimal --range-begin 0 --range-end 4

unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
-python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
+python3 run_suite.py --suite minimal --range-begin 4 --range-end 14

unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -84,7 +84,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
-python3 run_suite.py --suite minimal --range-begin 17 --range-end 20
+python3 run_suite.py --suite minimal --range-begin 14 --range-end 20

unit-test-backend-part-4:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
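
The three CI shards above are rebalanced because this PR adds a file to the `minimal` suite (see `test/srt/run_suite.py` below); the boundaries move from 0/5/17/20 to 0/4/14/20. A hedged sketch of the half-open range semantics this relies on (assuming `--range-begin`/`--range-end` slice the ordered test list, which this diff does not show):

```python
# Assumption: the ranges select suite[begin:end]; the three new shards then
# partition the suite exactly, with no overlap and no gap.
suite = [f"test_{i:02d}.py" for i in range(20)]  # placeholder names

shards = [(0, 4), (4, 14), (14, 20)]  # new boundaries from pr-test.yml
parts = [suite[begin:end] for begin, end in shards]

assert sum(parts, []) == suite
```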
1 change: 0 additions & 1 deletion docs/hyperparameter_tuning.md
@@ -1,7 +1,6 @@
# Guide on Hyperparameter Tuning

## Achieving Peak Throughput
-
Achieving a large batch size is the most important thing for attaining high throughput.

When the server is running at full load, look for the following in the log:
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -221,7 +221,7 @@ def __init__(
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
-self.is_inflight_req = 0
+self.is_being_chunked = 0

# Logprobs (arguments)
self.return_logprob = False
@@ -888,15 +888,15 @@ def prepare_for_decode(self, enable_overlap: bool = False):

def filter_batch(
self,
-current_inflight_req: Optional[Req] = None,
+being_chunked_req: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
-and self.reqs[i] is not current_inflight_req
+and self.reqs[i] is not being_chunked_req
]

if keep_indices is None or len(keep_indices) == 0:
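
The `filter_batch` change above keeps a request only if it is unfinished and is not the request currently being chunked; the chunked request is dropped because it re-enters the next prefill batch with fresh state. A standalone sketch of that keep rule (the `Req` stub is illustrative):

```python
class Req:  # minimal stub with just what the rule needs
    def __init__(self, done=False):
        self.done = done

    def finished(self):
        return self.done


def keep_indices(reqs, being_chunked_req=None):
    # Keep unfinished requests that are not the one being chunked,
    # mirroring ScheduleBatch.filter_batch above.
    return [
        i
        for i, req in enumerate(reqs)
        if not req.finished() and req is not being_chunked_req
    ]


reqs = [Req(), Req(done=True), Req()]
print(keep_indices(reqs, being_chunked_req=reqs[2]))  # [0]
```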
38 changes: 18 additions & 20 deletions python/sglang/srt/managers/scheduler.py
@@ -231,7 +231,7 @@ def __init__(

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
-self.current_inflight_req = None
+self.being_chunked_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
@@ -551,13 +551,13 @@ def get_next_batch_to_run(self):
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
-if self.current_inflight_req:
+if self.being_chunked_req:
self.last_batch.filter_batch(
-current_inflight_req=self.current_inflight_req
+being_chunked_req=self.being_chunked_req
)
-self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
-self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False
if not self.last_batch.is_empty():
if self.running_batch is None:
@@ -588,7 +588,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
-) and self.current_inflight_req is None:
+) and self.being_chunked_req is None:
return None

running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -611,13 +611,11 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
num_mixed_running,
)

-has_inflight = self.current_inflight_req is not None
+has_inflight = self.being_chunked_req is not None
if has_inflight:
-self.current_inflight_req.init_next_round_input(
-None if prefix_computed else self.tree_cache
-)
-self.current_inflight_req = adder.add_inflight_req(
-self.current_inflight_req
+self.being_chunked_req.init_next_round_input()
+self.being_chunked_req = adder.add_inflight_req(
+self.being_chunked_req
)

if self.lora_paths:
@@ -661,11 +659,11 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
]

if adder.new_inflight_req is not None:
-assert self.current_inflight_req is None
-self.current_inflight_req = adder.new_inflight_req
+assert self.being_chunked_req is None
+self.being_chunked_req = adder.new_inflight_req

-if self.current_inflight_req:
-self.current_inflight_req.is_inflight_req += 1
+if self.being_chunked_req:
+self.being_chunked_req.is_being_chunked += 1

# Print stats
if self.tp_rank == 0:
@@ -833,8 +831,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
logprob_pt = 0
for i, req in enumerate(batch.reqs):
-if req.is_inflight_req > 0:
-req.is_inflight_req -= 1
+if req.is_being_chunked > 0:
+req.is_being_chunked -= 1
else:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
@@ -860,8 +858,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
-if req.is_inflight_req > 0:
-req.is_inflight_req -= 1
+if req.is_being_chunked > 0:
+req.is_being_chunked -= 1
else:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
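
The actual leak fix sits in `get_next_batch_to_run` above: after a prefill batch that contained a partially prefilled request, that request is filtered out of the batch, its computed tokens are cached in the radix tree, and its `req_to_token_pool` slot is freed; previously that slot was never returned, so the pool drained over time. A simplified sketch of the slot lifecycle (the pool class is a stand-in, not the real sglang pool):

```python
# Stand-in for sglang's req_to_token_pool; only the slot bookkeeping
# relevant to the leak is modeled here.
class ReqToTokenPool:
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self):
        return self.free_slots.pop()

    def free(self, idx):
        self.free_slots.append(idx)


pool = ReqToTokenPool(size=4)
req_pool_idx = pool.alloc()  # slot held while one chunk is prefilled

# The fix: once the chunk's result is processed, release the slot.
# The request keeps its rid and is assigned a new req_pool_idx when
# its next chunk is scheduled; without this free() the slot leaked.
pool.free(req_pool_idx)
assert len(pool.free_slots) == 4
```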
4 changes: 1 addition & 3 deletions scripts/killall_sglang.sh
@@ -1,6 +1,4 @@
"""
Kill all SGLang processes and free the GPU memory.
"""
# Kill all SGLang processes and free the GPU memory.

kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
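
(The deleted header above was Python triple-quote docstring syntax, which is not valid in a shell script; the one-line `#` comment keeps the description without the syntax error.)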
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -19,6 +19,7 @@
"test_openai_server.py",
"test_overlap_schedule.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
112 changes: 112 additions & 0 deletions test/srt/test_radix_attention.py
@@ -0,0 +1,112 @@
import os
import random
import unittest

import requests

from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    kill_child_process,
    popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
    for _ in range(num0):
        parent = random.choice(nodes)
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        child = {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }
        nodes.append(child)

    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            unique_len = random.randint(0, chunk_len)
            decode_len = random.randint(0, chunk_len)
            token_id = random.randint(0, 32000)
            child = {
                "input_ids": parent["input_ids"] + [token_id] * unique_len,
                "decode_len": decode_len,
            }
            nodes.append(child)

        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def run_test(base_url, nodes):
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    assert res.status_code == 200


class TestRadixCacheFCFS(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)

    def test_radix_attention(self):
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


if __name__ == "__main__":
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
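
For reference, a hedged sketch of running the new test locally, mirroring its `__main__` block (assumes a development checkout of sglang with a GPU and the test's default model available):

```python
# Hypothetical local invocation; the path assumes the repo root as cwd.
import os
import subprocess

env = dict(os.environ, SGLANG_TEST_RETRACT="true")  # as in __main__ above
subprocess.run(
    ["python3", "test_radix_attention.py"],
    cwd="test/srt",
    env=env,
    check=True,
)
```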