Skip to content

Commit

Permalink
[Production] Drain requests before exit when receive SIGTERM (#1838)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Oct 30, 2024
1 parent 3184aa9 commit 4e2af03
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import json
import logging
import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union

import fastapi
Expand Down Expand Up @@ -58,7 +60,12 @@
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Expand Down Expand Up @@ -142,6 +149,9 @@ def __init__(
self.model_update_lock = asyncio.Lock()
self.model_update_result = None

# Others
self.gracefully_exit = False

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
Expand Down Expand Up @@ -629,6 +639,28 @@ def create_handle_loop(self):
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())

signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.create_task(self.sigterm_watchdog())

async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(60)

# drain requests
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
f"gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
else:
break

kill_child_process(include_self=True)
sys.exit(-1)

async def handle_loop(self):
"""The event loop that handles requests"""

Expand Down Expand Up @@ -740,3 +772,14 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
token_top_logprobs, decode_to_text
)
return top_logprobs


class SignalHandler:
def __init__(self, tokenizer_manager):
self.tokenizer_manager = tokenizer_manager

def signal_handler(self, signum=None, frame=None):
logger.warning(
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
)
self.tokenizer_manager.gracefully_exit = True

0 comments on commit 4e2af03

Please sign in to comment.