Commit: Simplify batch update

merrymercy committed Nov 24, 2024
1 parent 7921690 commit 67b30b3
Showing 3 changed files with 22 additions and 26 deletions.
python/sglang/srt/managers/schedule_batch.py: 4 changes (3 additions, 1 deletion)

@@ -467,6 +467,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -722,7 +723,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens += running_bs
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
@@ -732,6 +732,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
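Aside: the hunk above keeps the extend bookkeeping consistent. When a running decode batch is mixed into a prefill batch, each running request contributes exactly one extended token, so extend_lens grows by [1] * running_bs and extend_num_tokens by running_bs, keeping extend_num_tokens equal to sum(extend_lens). A minimal standalone sketch of that invariant (toy dataclass, not sglang's actual ScheduleBatch):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyBatch:
    # Mirrors only the three fields touched by this commit.
    extend_lens: List[int] = field(default_factory=list)
    extend_logprob_start_lens: List[int] = field(default_factory=list)
    extend_num_tokens: int = 0

    def mix_with_running(self, running_bs: int) -> None:
        # Each mixed-in decode request extends by exactly one token.
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # Matches the TODO above: start logprobs at 0 for mixed-in requests.
        self.extend_logprob_start_lens.extend([0] * running_bs)

batch = ToyBatch(extend_lens=[5, 3], extend_logprob_start_lens=[0, 0], extend_num_tokens=8)
batch.mix_with_running(running_bs=2)
assert batch.extend_num_tokens == sum(batch.extend_lens)  # 10 == 5 + 3 + 1 + 1

Moving the extend_num_tokens update next to the extend_lens update, as this commit does, keeps the two fields in sync at a single call site.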
python/sglang/srt/managers/scheduler.py: 42 changes (18 additions, 24 deletions)

@@ -13,7 +13,6 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -721,40 +720,30 @@ def check_memory(self):

     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
 
+            if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        # Prefill first
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-        before_bs = self.running_batch.batch_size()
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-        if before_bs != self.running_batch.batch_size():
-            self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -866,15 +855,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-        batch = self.running_batch
 
+        initial_bs = batch.batch_size()
+
         batch.filter_batch()
         if batch.is_empty():
-            self.running_batch = None
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -900,11 +890,15 @@ def update_running_batch(self):
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.running_batch = None
-                return
+                self.batch_is_full = False
+                return None
 
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
+
         # Update batch tensors
         batch.prepare_for_decode(self.enable_overlap)
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
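Aside: a condensed sketch of the scheduler flow after this commit, using toy types and a hypothetical request dict in place of sglang's real Scheduler and ScheduleBatch. It shows the two behaviors the diff establishes: prefill batches take priority over decode, and update_running_batch now takes the batch as an argument and returns Optional, so the caller owns the reassignment and an emptied batch becomes None in one place:

from typing import Optional

class ToyScheduler:
    def __init__(self):
        self.running_batch: Optional[list] = None
        self.batch_is_full = False

    def get_new_batch_prefill(self) -> Optional[list]:
        # Pretend no prefill work is waiting.
        return None

    def update_running_batch(self, batch: list) -> Optional[list]:
        # Drop finished requests; an empty result frees the decode slot.
        batch = [req for req in batch if not req["done"]]
        if not batch:
            self.batch_is_full = False
            return None
        return batch

    def get_next_batch_to_run(self) -> Optional[list]:
        # Prefill has priority; otherwise advance the decode batch.
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

scheduler = ToyScheduler()
scheduler.running_batch = [{"done": False}, {"done": True}]
print(scheduler.get_next_batch_to_run())  # [{'done': False}]

Note how the old caller-side before_bs comparison becomes the initial_bs check inside update_running_batch, which is also where batch_is_full is now reset.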
python/sglang/test/few_shot_gsm8k.py: 2 changes (1 addition, 1 deletion)

@@ -135,7 +135,7 @@ def few_shot_gsm8k(s, question):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--num-shots", type=int, default=5)
-    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--data-path", type=str)
     parser.add_argument("--num-questions", type=int, default=200)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
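Aside: with the default removed, args.data_path is None unless the flag is given, letting the script pick a fallback at runtime rather than assuming a local test.jsonl exists. A minimal sketch of that pattern (resolve_data_path and its fallback path are hypothetical, not the actual logic in sglang's test utilities):

import argparse

def resolve_data_path(data_path):
    # Hypothetical fallback: the real script's behavior is not shown in this diff.
    if data_path is not None:
        return data_path
    return "downloaded/gsm8k_test.jsonl"  # assumed cache location

parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str)  # no default, so it is None when omitted
args = parser.parse_args([])
print(resolve_data_path(args.data_path))  # downloaded/gsm8k_test.jsonl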
