diff --git a/python/sglang/bench_server_latency.py b/python/sglang/bench_server_latency.py
index 66e59d0d4e0..57506913f51 100644
--- a/python/sglang/bench_server_latency.py
+++ b/python/sglang/bench_server_latency.py
@@ -17,7 +17,7 @@
 import multiprocessing
 import os
 import time
-from typing import Optional, Tuple
+from typing import Tuple
 
 import numpy as np
 import requests
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 40463d01610..e3fea477788 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -775,7 +775,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             else:
                 self.tree_cache.cache_unfinished_req(req)
 
-        self.stream_output(batch)
+        self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
@@ -815,7 +815,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.stream_output(batch)
+        self.stream_output(batch.reqs)
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -894,7 +894,7 @@ def add_logprob_return_values(
 
         return num_input_logprobs
 
-    def stream_output(self, batch: ScheduleBatch):
+    def stream_output(self, reqs: List[Req]):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -911,7 +911,7 @@ def stream_output(self, batch: ScheduleBatch):
 
         is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
 
-        for req in batch.reqs:
+        for req in reqs:
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
@@ -1025,8 +1025,9 @@ def abort_request(self, recv_req: AbortReq):
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid:
+                if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
+                    self.tree_cache.cache_finished_req(req)
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):