Multi-threaded memory sharing #1

Open · wants to merge 2 commits into base: flex_scheduler
2 changes: 1 addition & 1 deletion benchmark/latency_throughput/bench_serving.py
@@ -338,7 +338,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset", type=str, help="Path to the dataset.")
parser.add_argument("--input-len", type=int, default=2048)
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--range-ratio", type=float, default=1.0)
parser.add_argument("--range-ratio", type=float, default=0.4)
parser.add_argument(
"--tokenizer",
type=str,
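For context on the default change: if bench_serving follows the common pattern of sampling each synthetic request's length uniformly from [range_ratio * len, len] (an assumption about code not shown in this diff), lowering the default from 1.0 to 0.4 turns fixed-length requests into variable-length ones. A minimal sketch of that sampling pattern:

import numpy as np

def sample_lengths(base_len: int, range_ratio: float, n: int) -> np.ndarray:
    # Draw n lengths uniformly from [range_ratio * base_len, base_len].
    low = int(base_len * range_ratio)
    return np.random.randint(low, base_len + 1, size=n)

# With the new default (0.4) and --input-len 2048, sampled inputs would
# vary between 819 and 2048 tokens instead of always being 2048.
input_lens = sample_lengths(2048, 0.4, 100)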
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -53,6 +53,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.profile.process_scheduler import ProcessScheduler
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
@@ -98,6 +99,10 @@ def __init__(
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)

self.profile_scheduler = ProcessScheduler()

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
@@ -106,9 +111,12 @@
tp_size=server_args.tp_size,
nccl_port=nccl_port,
server_args=server_args,
profile_scheduler=self.profile_scheduler,
)

if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
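The worker now owns one ProcessScheduler and passes it down to ModelRunner. Underneath, the sharing relies on the standard multiprocessing.Manager pattern: a manager server process owns the state, and child processes mutate it through proxies. A minimal standalone sketch of that pattern (the names here are illustrative, not from this PR):

import multiprocessing

def writer(ns, value):
    # The child process mutates the shared state through Manager proxies.
    ns.samples.append(value)

if __name__ == "__main__":
    manager = multiprocessing.Manager()
    ns = manager.Namespace()
    ns.samples = manager.list()

    p = multiprocessing.Process(target=writer, args=(ns, 42))
    p.start()
    p.join()
    print(list(ns.samples))  # -> [42]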
28 changes: 28 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -59,6 +59,7 @@
monkey_patch_vllm_p2p_access_check,
monkey_patch_vllm_qvk_linear_loader,
)
from sglang.srt.profile.process_scheduler import ProcessScheduler

logger = logging.getLogger(__name__)

@@ -73,8 +74,12 @@ def __init__(
tp_size: int,
nccl_port: int,
server_args: ServerArgs,
profile_scheduler: ProcessScheduler,
):
# Parse args
self.profile_scheduler = profile_scheduler
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.gpu_id = gpu_id
@@ -138,6 +143,9 @@ def __init__(
# Capture cuda graphs
self.init_cuda_graphs()

# Plot counter for periodic profile dumps
self.plot_count = 0

def load_model(self):
logger.info(
f"[gpu={self.gpu_id}] Load weight begin. "
@@ -280,6 +288,7 @@ def init_memory_pool(
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)

logger.info(
f"[gpu={self.gpu_id}] Memory pool end. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -389,6 +398,16 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch):
)

def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
if batch.input_ids is not None:
    self.plot_count += 1
    # Record one sample (batch tokens, free KV cache slots, forward mode)
    # through a short-lived writer process.
    p = self.profile_scheduler.create_process(
        write_batch_data,
        (self.profile_scheduler.get_shared_object(), batch.input_ids.numel(),
         self.token_to_kv_pool.available_size(), forward_mode,
         self.token_to_kv_pool.size),
    )
    self.profile_scheduler.start_process(p)
    self.profile_scheduler.join_process(p)

    # Every 100 samples, refresh the plot and dump the profile log.
    if self.plot_count % 100 == 0:
        self.profile_scheduler.plot_profile_data()
        self.profile_scheduler.print_profile_data()
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE:
@@ -399,6 +418,15 @@ def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
raise ValueError(f"Invalid forward mode: {forward_mode}")


def write_batch_data(shared_object, batch_data, mem_data, type, max_size):
    # Runs in a child process: append one profiling sample to the
    # Manager-backed shared namespace.
    shared_object.batch_data.append(batch_data)
    shared_object.mem_data.append(mem_data)
    shared_object.type.append(type)
    shared_object.max_size = max_size


@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
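A note on how these samples are interpreted by the new module below: plot_profile_data treats each mem_data entry as free KV cache slots and converts it to a used-memory percentage via (max_size - available) / max_size * 100. With an assumed pool of 200,000 slots and 150,000 free, for example:

max_size, available = 200_000, 150_000  # assumed sample values, for illustration
used_pct = (max_size - available) / max_size * 100  # -> 25.0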
75 changes: 75 additions & 0 deletions python/sglang/srt/profile/process_scheduler.py
@@ -0,0 +1,75 @@
import logging
import multiprocessing

# Create a dedicated logger for profiling output.
logger = logging.getLogger('profile')
logger.setLevel(logging.DEBUG)

# Create a handler that writes the log to a file.
fh = logging.FileHandler('/data/home/josephyou/WXG_WORK/sglang/python/sglang/srt/profile/profile.log')
fh.setLevel(logging.DEBUG)

# Define the handler's output format.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)

# Attach the handler to the logger.
logger.addHandler(fh)


class ProcessScheduler:
    def __init__(self) -> None:
        # Create a Manager whose server process owns the shared state.
        self.manager = multiprocessing.Manager()

        # Create a shared namespace that holds the profiling data.
        self.shared_object = self.manager.Namespace()

        self.shared_object.mem_data = self.manager.list()
        self.shared_object.batch_data = self.manager.list()
        self.shared_object.type = self.manager.list()
        self.shared_object.max_size = 0.0

    def create_process(self, target, args):
        process = multiprocessing.Process(target=target, args=args)
        return process

    def start_process(self, process):
        process.start()

    def join_process(self, process):
        process.join()

    def get_shared_object(self):
        return self.shared_object

    def print_profile_data(self):
        # Dump the recorded samples to the profile log as a tab-separated table.
        logger.info("id\t\tmem\t\tbatch\t\ttype")
        length = min(
            len(self.shared_object.mem_data),
            len(self.shared_object.batch_data),
            len(self.shared_object.type),
        )
        for i in range(length):
            logger.info(
                f"{i}\t\t{self.shared_object.mem_data[i]}\t\t"
                f"{self.shared_object.batch_data[i]}\t\t{self.shared_object.type[i]}"
            )

    def plot_profile_data(self):
        import matplotlib.pyplot as plt

        # mem_data holds free KV slots; convert each sample to a used percentage.
        gpu_memory_usage_percentage = [
            (float(self.shared_object.max_size - usage) / self.shared_object.max_size) * 100
            for usage in self.shared_object.mem_data
        ]
        # batch_data holds tokens per batch, normalized against a 512-token budget.
        compute_resource_usage = [
            (float(usage) / 512) * 100 for usage in self.shared_object.batch_data
        ]

        plt.figure(figsize=(10, 6))
        plt.plot(gpu_memory_usage_percentage, label='GPU Memory Usage (%)')
        plt.plot(compute_resource_usage, label='Compute Resource Usage (%)')
        plt.legend()
        plt.xlabel("Steps")
        plt.ylabel("Rate (%)")
        plt.savefig('/data/home/josephyou/WXG_WORK/sglang/python/sglang/srt/profile/profile.png')

        print("memory usage plot saved...")