-
Notifications
You must be signed in to change notification settings - Fork 743
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix memory leak for chunked prefill 2 (#1858)
Co-authored-by: Liangsheng Yin <[email protected]>
- Loading branch information
1 parent
8ce202a
commit a2e0424
Showing
7 changed files
with
138 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,4 @@ | ||
""" | ||
Kill all SGLang processes and free the GPU memory. | ||
""" | ||
# Kill all SGLang processes and free the GPU memory. | ||
|
||
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') | ||
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import os | ||
import random | ||
import unittest | ||
|
||
import requests | ||
|
||
from sglang.test.test_utils import ( | ||
DEFAULT_MODEL_NAME_FOR_TEST, | ||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
DEFAULT_URL_FOR_TEST, | ||
kill_child_process, | ||
popen_launch_server, | ||
) | ||
|
||
|
||
def gen_radix_tree(num_nodes=400, chunk_len=256): | ||
num0 = num_nodes // 2 | ||
num1 = num_nodes - num0 | ||
nodes = [{"input_ids": [37] * 117, "decode_len": 217}] | ||
for _ in range(num0): | ||
parent = random.choice(nodes) | ||
unique_len = random.randint(0, chunk_len) | ||
decode_len = random.randint(0, chunk_len) | ||
token_id = random.randint(0, 32000) | ||
child = { | ||
"input_ids": parent["input_ids"] + [token_id] * unique_len, | ||
"decode_len": decode_len, | ||
} | ||
nodes.append(child) | ||
|
||
while num1 > 0: | ||
num_branch = random.randint(1, min(num1, 10)) | ||
parent = random.choice(nodes) | ||
for _ in range(num_branch): | ||
unique_len = random.randint(0, chunk_len) | ||
decode_len = random.randint(0, chunk_len) | ||
token_id = random.randint(0, 32000) | ||
child = { | ||
"input_ids": parent["input_ids"] + [token_id] * unique_len, | ||
"decode_len": decode_len, | ||
} | ||
nodes.append(child) | ||
|
||
num1 -= num_branch | ||
|
||
random.shuffle(nodes) | ||
return nodes | ||
|
||
|
||
def run_test(base_url, nodes): | ||
data = { | ||
"input_ids": [node["input_ids"] for node in nodes], | ||
"sampling_params": [ | ||
{"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes | ||
], | ||
} | ||
|
||
res = requests.post(base_url + "/generate", json=data) | ||
assert res.status_code == 200 | ||
|
||
|
||
class TestRadixCacheFCFS(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST | ||
cls.base_url = DEFAULT_URL_FOR_TEST | ||
cls.process = popen_launch_server( | ||
cls.model, | ||
cls.base_url, | ||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
other_args=[ | ||
"--chunked-prefill-size", | ||
"128", | ||
"--max-total-tokens", | ||
"20000", | ||
"--schedule-policy", | ||
"fcfs", | ||
], | ||
) | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
kill_child_process(cls.process.pid, include_self=True) | ||
|
||
def test_radix_attention(self): | ||
nodes = gen_radix_tree() | ||
run_test(self.base_url, nodes) | ||
|
||
|
||
class TestRadixCacheLPM(TestRadixCacheFCFS): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST | ||
cls.base_url = DEFAULT_URL_FOR_TEST | ||
cls.process = popen_launch_server( | ||
cls.model, | ||
cls.base_url, | ||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
other_args=[ | ||
"--chunked-prefill-size", | ||
"128", | ||
"--max-total-tokens", | ||
"20000", | ||
"--schedule-policy", | ||
"lpm", | ||
], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
os.environ["SGLANG_TEST_RETRACT"] = "true" | ||
unittest.main() |