Simplify the interface of tp_worker (#1718)
merrymercy authored Oct 20, 2024
1 parent b6cd903 commit 12cad0f
Showing 2 changed files with 41 additions and 29 deletions.
53 changes: 25 additions & 28 deletions python/sglang/srt/managers/scheduler.py
@@ -91,6 +91,7 @@ def __init__(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -144,13 +145,24 @@ def __init__(
 
         # Launch a tensor parallel worker
         self.tp_worker = TpModelWorker(
-            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device
 
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+
         # Get token and memory info from the model worker
         (
@@ -159,11 +171,11 @@ def __init__(
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         ) = self.tp_worker.get_token_and_memory_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
 
         # Print debug info
         logger.info(
@@ -173,9 +185,8 @@ def __init__(
             f"context_len={self.model_config.context_len}"
         )
 
-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
 
         if (
             server_args.chunked_prefill_size is not None
@@ -253,20 +264,6 @@ def __init__(
                 with_stack=True,
             )
 
-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
@@ -779,7 +776,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 req.check_finished()
 
                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -808,7 +805,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 req.check_finished()
 
                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 else:
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -845,7 +842,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
                 )
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1069,7 +1066,7 @@ def abort_request(self, recv_req: AbortReq):
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
     suppress_other_loggers()
 
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
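The overlap-schedule block moved to the top of __init__ binds forward_batch_generation and resolve_next_token_ids once, so the event loops call the same two attributes whether or not overlap scheduling is enabled; the commit also drops the cache_finished_req alias, and call sites now go through self.tree_cache.cache_finished_req directly, as the later hunks show. Below is a minimal, self-contained sketch of that init-time dispatch pattern; the WorkerSketch and SchedulerSketch classes and their internals are illustrative stand-ins, not sglang's actual implementation.

from typing import List

class WorkerSketch:
    """Hypothetical stand-in for TpModelWorker, used only to show the pattern."""

    def forward_batch_generation(self, batch: List[int]) -> List[int]:
        # Blocking path: return real token ids immediately.
        return [token + 1 for token in batch]

    def forward_batch_generation_non_blocking(self, batch: List[int]) -> List[int]:
        # Non-blocking path: stash the real ids and hand back placeholder "future" ids.
        self._pending = [token + 1 for token in batch]
        return list(range(-len(batch), 0))

    def resolve_future_token_ids(self, bid: int) -> List[int]:
        # Resolve the placeholders produced by the non-blocking path.
        return self._pending


class SchedulerSketch:
    def __init__(self, worker: WorkerSketch, enable_overlap_schedule: bool):
        # Bind both call sites once, as the diff above does, so the event loop
        # never branches on the scheduling mode.
        if enable_overlap_schedule:
            self.forward_batch_generation = worker.forward_batch_generation_non_blocking
            self.resolve_next_token_ids = lambda bid, x: worker.resolve_future_token_ids(bid)
        else:
            self.forward_batch_generation = worker.forward_batch_generation
            self.resolve_next_token_ids = lambda bid, x: list(x)

    def step(self, bid: int, batch: List[int]) -> List[int]:
        next_ids = self.forward_batch_generation(batch)
        return self.resolve_next_token_ids(bid, next_ids)


if __name__ == "__main__":
    for overlap in (False, True):
        sched = SchedulerSketch(WorkerSketch(), enable_overlap_schedule=overlap)
        print(overlap, sched.step(bid=0, batch=[10, 20, 30]))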
17 changes: 16 additions & 1 deletion python/sglang/srt/managers/tp_worker.py
@@ -20,6 +20,7 @@
 import threading
 import time
 from queue import Queue
+from typing import Optional
 
 import torch
 
@@ -40,9 +41,10 @@ class TpModelWorker:
 
     def __init__(
         self,
-        server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        server_args: ServerArgs,
+        dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Parse args
@@ -116,6 +118,19 @@ def get_token_and_memory_info(self):
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         )
 
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
+        )
+
     def init_overlap_status(self):
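With the accessors above, the scheduler no longer reaches into tp_worker.model_runner for the CPU process group, the padding function, or the memory pools; everything flows through TpModelWorker's own methods, and the info tuple now also carries the device. The following is a minimal sketch of how the new surface is consumed, mirroring the scheduler diff; read_worker_state is a hypothetical helper name, and the caller must already hold a constructed TpModelWorker.

def read_worker_state(tp_worker):
    # Hypothetical helper for illustration; mirrors what Scheduler.__init__ now does.
    info = tp_worker.get_token_and_memory_info()
    device = info[-1]  # the device is the last element of the tuple as of this commit
    tp_cpu_group = tp_worker.get_tp_cpu_group()              # was model_runner.tp_group.cpu_group
    pad_input_ids_func = tp_worker.get_pad_input_ids_func()  # None if the model defines no pad_input_ids
    req_to_token_pool, token_to_kv_pool = tp_worker.get_memory_pool()
    return device, tp_cpu_group, pad_input_ids_func, req_to_token_pool, token_to_kv_pool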
