Skip to content

Commit

Permalink
Fix memory leak during abort (#1674)
Browse files (browse the repository at this point in the history)
  • Loading branch information
merrymercy authored Oct 15, 2024
1 parent 175afed commit f1088e0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/sglang/bench_server_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import multiprocessing
import os
import time
-from typing import Optional, Tuple
+from typing import Tuple

import numpy as np
import requests
Expand Down
11 changes: 6 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
else:
self.tree_cache.cache_unfinished_req(req)

-self.stream_output(batch)
+self.stream_output(batch.reqs)

def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
Expand Down Expand Up @@ -815,7 +815,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

-self.stream_output(batch)
+self.stream_output(batch.reqs)

self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
Expand Down Expand Up @@ -894,7 +894,7 @@ def add_logprob_return_values(

return num_input_logprobs

-def stream_output(self, batch: ScheduleBatch):
+def stream_output(self, reqs: List[Req]):
output_rids = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
Expand All @@ -911,7 +911,7 @@ def stream_output(self, batch: ScheduleBatch):

is_stream_iter = self.decode_forward_ct % self.stream_interval == 0

-for req in batch.reqs:
+for req in reqs:
if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1)
):
Expand Down Expand Up @@ -1025,8 +1025,9 @@ def abort_request(self, recv_req: AbortReq):
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
-if req.rid == recv_req.rid:
+if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT()
+self.tree_cache.cache_finished_req(req)
break

def update_weights(self, recv_req: UpdateWeightReqInput):
Expand Down

0 comments on commit f1088e0

Please sign in to comment.