Skip to content

Commit 9e32e8e

Browse files
Cooperative signal handling (#1600)
* test desired signal behaviour * capture and restore signal handlers * ruff * checks * test asyncio handlers * add note on signal handler handling * remove legacy signal raising * test SIGBREAK on windows * remove test guard * include convered branch * Update docs/index.md * Update docs/index.md --------- Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent f73b8be commit 9e32e8e

File tree

2 files changed

+92
-14
lines changed

2 files changed

+92
-14
lines changed

tests/test_server.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import contextlib
5+
import signal
6+
import sys
7+
from typing import Callable, ContextManager, Generator
8+
9+
import pytest
10+
11+
from uvicorn.config import Config
12+
from uvicorn.server import Server
13+
14+
15+
# asyncio does NOT allow raising in signal handlers, so to detect
16+
# raised signals raised a mutable `witness` receives the signal
17+
@contextlib.contextmanager
18+
def capture_signal_sync(sig: signal.Signals) -> Generator[list[int], None, None]:
19+
"""Replace `sig` handling with a normal exception via `signal"""
20+
witness: list[int] = []
21+
original_handler = signal.signal(sig, lambda signum, frame: witness.append(signum))
22+
yield witness
23+
signal.signal(sig, original_handler)
24+
25+
26+
@contextlib.contextmanager
27+
def capture_signal_async(sig: signal.Signals) -> Generator[list[int], None, None]: # pragma: py-win32
28+
"""Replace `sig` handling with a normal exception via `asyncio"""
29+
witness: list[int] = []
30+
original_handler = signal.getsignal(sig)
31+
asyncio.get_running_loop().add_signal_handler(sig, witness.append, sig)
32+
yield witness
33+
signal.signal(sig, original_handler)
34+
35+
36+
async def dummy_app(scope, receive, send): # pragma: py-win32
37+
pass
38+
39+
40+
if sys.platform == "win32":
41+
signals = [signal.SIGBREAK]
42+
signal_captures = [capture_signal_sync]
43+
else:
44+
signals = [signal.SIGTERM, signal.SIGINT]
45+
signal_captures = [capture_signal_sync, capture_signal_async]
46+
47+
48+
@pytest.mark.anyio
49+
@pytest.mark.parametrize("exception_signal", signals)
50+
@pytest.mark.parametrize("capture_signal", signal_captures)
51+
async def test_server_interrupt(
52+
exception_signal: signal.Signals, capture_signal: Callable[[signal.Signals], ContextManager[None]]
53+
): # pragma: py-win32
54+
"""Test interrupting a Server that is run explicitly inside asyncio"""
55+
56+
async def interrupt_running(srv: Server):
57+
while not srv.started:
58+
await asyncio.sleep(0.01)
59+
signal.raise_signal(exception_signal)
60+
61+
server = Server(Config(app=dummy_app, loop="asyncio"))
62+
asyncio.create_task(interrupt_running(server))
63+
with capture_signal(exception_signal) as witness:
64+
await server.serve()
65+
assert witness
66+
# set by the server's graceful exit handler
67+
assert server.should_exit

uvicorn/server.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import logging
56
import os
67
import platform
@@ -11,7 +12,7 @@
1112
import time
1213
from email.utils import formatdate
1314
from types import FrameType
14-
from typing import TYPE_CHECKING, Sequence, Union
15+
from typing import TYPE_CHECKING, Generator, Sequence, Union
1516

1617
import click
1718

@@ -57,11 +58,17 @@ def __init__(self, config: Config) -> None:
5758
self.force_exit = False
5859
self.last_notified = 0.0
5960

61+
self._captured_signals: list[int] = []
62+
6063
def run(self, sockets: list[socket.socket] | None = None) -> None:
6164
self.config.setup_event_loop()
6265
return asyncio.run(self.serve(sockets=sockets))
6366

6467
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
68+
with self.capture_signals():
69+
await self._serve(sockets)
70+
71+
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
6572
process_id = os.getpid()
6673

6774
config = self.config
@@ -70,8 +77,6 @@ async def serve(self, sockets: list[socket.socket] | None = None) -> None:
7077

7178
self.lifespan = config.lifespan_class(config)
7279

73-
self.install_signal_handlers()
74-
7580
message = "Started server process [%d]"
7681
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
7782
logger.info(message, process_id, extra={"color_message": color_message})
@@ -302,22 +307,28 @@ async def _wait_tasks_to_complete(self) -> None:
302307
for server in self.servers:
303308
await server.wait_closed()
304309

305-
def install_signal_handlers(self) -> None:
310+
@contextlib.contextmanager
311+
def capture_signals(self) -> Generator[None, None, None]:
312+
# Signals can only be listened to from the main thread.
306313
if threading.current_thread() is not threading.main_thread():
307-
# Signals can only be listened to from the main thread.
314+
yield
308315
return
309-
310-
loop = asyncio.get_event_loop()
311-
316+
# always use signal.signal, even if loop.add_signal_handler is available
317+
# this allows to restore previous signal handlers later on
318+
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
312319
try:
313-
for sig in HANDLED_SIGNALS:
314-
loop.add_signal_handler(sig, self.handle_exit, sig, None)
315-
except NotImplementedError: # pragma: no cover
316-
# Windows
317-
for sig in HANDLED_SIGNALS:
318-
signal.signal(sig, self.handle_exit)
320+
yield
321+
finally:
322+
for sig, handler in original_handlers.items():
323+
signal.signal(sig, handler)
324+
# If we did gracefully shut down due to a signal, try to
325+
# trigger the expected behaviour now; multiple signals would be
326+
# done LIFO, see https://stackoverflow.com/questions/48434964
327+
for captured_signal in reversed(self._captured_signals):
328+
signal.raise_signal(captured_signal)
319329

320330
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
331+
self._captured_signals.append(sig)
321332
if self.should_exit and sig == signal.SIGINT:
322333
self.force_exit = True
323334
else:

0 commit comments

Comments
 (0)