diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ad56be197e7..c9f0ea676f2 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -467,6 +467,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -722,7 +723,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens += running_bs
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
@@ -732,6 +732,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index dbe18c129a2..f6f6e72e025 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -13,7 +13,6 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -721,40 +720,30 @@ def check_memory(self):
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
+            if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        # Prefill first
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-        before_bs = self.running_batch.batch_size()
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-        if before_bs != self.running_batch.batch_size():
-            self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -866,15 +855,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-        batch = self.running_batch
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.running_batch = None
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -900,11 +890,15 @@ def update_running_batch(self):
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.running_batch = None
-                return
+                self.batch_is_full = False
+                return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
         batch.prepare_for_decode(self.enable_overlap)
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py
index 8e6572da61a..9657e730084 100644
--- a/python/sglang/test/few_shot_gsm8k.py
+++ b/python/sglang/test/few_shot_gsm8k.py
@@ -135,7 +135,7 @@ def few_shot_gsm8k(s, question):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--num-shots", type=int, default=5)
-    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--data-path", type=str)
    parser.add_argument("--num-questions", type=int, default=200)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)