Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use asyncio.run(..., loop_factory) to avoid asyncio.set_event_loop_policy #2130

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions tests/test_auto_detection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import contextlib
import importlib

import pytest

from uvicorn.config import Config
from uvicorn.loops.auto import auto_loop_setup
from uvicorn.loops.auto import auto_loop_factory
from uvicorn.main import ServerState
from uvicorn.protocols.http.auto import AutoHTTPProtocol
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol
Expand Down Expand Up @@ -33,10 +34,10 @@ async def app(scope, receive, send):


def test_loop_auto():
auto_loop_setup()
policy = asyncio.get_event_loop_policy()
assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy)
assert type(policy).__module__.startswith(expected_loop)
loop_factory = auto_loop_factory()
with contextlib.closing(loop_factory()) as loop:
assert isinstance(loop, asyncio.AbstractEventLoop)
assert type(loop).__module__.startswith(expected_loop)


@pytest.mark.anyio
Expand Down
85 changes: 85 additions & 0 deletions uvicorn/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

_T = TypeVar("_T")

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("asyncio.run() cannot be called from a running event loop")

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if sys.version_info >= (3, 9):
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
21 changes: 11 additions & 10 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
HTTPProtocolType = Literal["auto", "h11", "httptools"]
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]

LOG_LEVELS: dict[str, int] = {
Expand All @@ -53,11 +53,11 @@
"on": "uvicorn.lifespan.on:LifespanOn",
"off": "uvicorn.lifespan.off:LifespanOff",
}
LOOP_SETUPS: dict[LoopSetupType, str | None] = {
LOOP_FACTORIES: dict[LoopFactoryType, str | None] = {
"none": None,
"auto": "uvicorn.loops.auto:auto_loop_setup",
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
"auto": "uvicorn.loops.auto:auto_loop_factory",
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
}
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]

Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(
port: int = 8000,
uds: str | None = None,
fd: int | None = None,
loop: LoopSetupType = "auto",
loop: LoopFactoryType = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
ws_max_size: int = 16 * 1024 * 1024,
Expand Down Expand Up @@ -471,10 +471,11 @@ def load(self) -> None:

self.loaded = True

def setup_event_loop(self) -> None:
loop_setup: Callable | None = import_from_string(LOOP_SETUPS[self.loop])
if loop_setup is not None:
loop_setup(use_subprocess=self.use_subprocess)
def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
loop_factory: Callable | None = import_from_string(LOOP_FACTORIES[self.loop])
if loop_factory is None:
return None
return loop_factory(use_subprocess=self.use_subprocess)

def bind_socket(self) -> socket.socket:
logger_args: list[str | int]
Expand Down
13 changes: 7 additions & 6 deletions uvicorn/loops/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import asyncio
import logging
import sys

logger = logging.getLogger("uvicorn.error")
from collections.abc import Callable


def asyncio_setup(use_subprocess: bool = False) -> None:
if sys.platform == "win32" and use_subprocess:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage
def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
if sys.platform == "win32" and not use_subprocess:
return asyncio.ProactorEventLoop
return asyncio.SelectorEventLoop
16 changes: 11 additions & 5 deletions uvicorn/loops/auto.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
def auto_loop_setup(use_subprocess: bool = False) -> None:
from __future__ import annotations

import asyncio
from collections.abc import Callable


def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
try:
import uvloop # noqa
except ImportError: # pragma: no cover
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory

loop_setup(use_subprocess=use_subprocess)
return loop_factory(use_subprocess=use_subprocess)
else: # pragma: no cover
from uvicorn.loops.uvloop import uvloop_setup
from uvicorn.loops.uvloop import uvloop_loop_factory

uvloop_setup(use_subprocess=use_subprocess)
return uvloop_loop_factory(use_subprocess=use_subprocess)
7 changes: 5 additions & 2 deletions uvicorn/loops/uvloop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable

import uvloop


def uvloop_setup(use_subprocess: bool = False) -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uvloop.EventLoopPolicy is deprecated

def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
return uvloop.new_event_loop
10 changes: 5 additions & 5 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
LIFESPAN,
LOG_LEVELS,
LOGGING_CONFIG,
LOOP_SETUPS,
LOOP_FACTORIES,
SSL_PROTOCOL_VERSION,
WS_PROTOCOLS,
Config,
HTTPProtocolType,
InterfaceType,
LifespanType,
LoopSetupType,
LoopFactoryType,
WSProtocolType,
)
from uvicorn.server import Server, ServerState # noqa: F401 # Used to be defined here.
Expand All @@ -36,7 +36,7 @@
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
LOOP_CHOICES = click.Choice([key for key in LOOP_FACTORIES.keys() if key != "none"])
INTERFACE_CHOICES = click.Choice(INTERFACES)

STARTUP_FAILURE = 3
Expand Down Expand Up @@ -364,7 +364,7 @@ def main(
port: int,
uds: str,
fd: int,
loop: LoopSetupType,
loop: LoopFactoryType,
http: HTTPProtocolType,
ws: WSProtocolType,
ws_max_size: int,
Expand Down Expand Up @@ -465,7 +465,7 @@ def run(
port: int = 8000,
uds: str | None = None,
fd: int | None = None,
loop: LoopSetupType = "auto",
loop: LoopFactoryType = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
ws_max_size: int = 16777216,
Expand Down
4 changes: 2 additions & 2 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import click

from uvicorn._compat import asyncio_run
from uvicorn.config import Config

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,8 +62,7 @@ def __init__(self, config: Config) -> None:
self._captured_signals: list[int] = []

def run(self, sockets: list[socket.socket] | None = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())

async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():
Expand Down
7 changes: 2 additions & 5 deletions uvicorn/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker

from uvicorn._compat import asyncio_run
from uvicorn.config import Config
from uvicorn.main import Server

Expand Down Expand Up @@ -70,10 +71,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

self.config = Config(**config_kwargs)

def init_process(self) -> None:
self.config.setup_event_loop()
super().init_process()

def init_signals(self) -> None:
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# other signals are set up by Server.install_signal_handlers()
Expand Down Expand Up @@ -104,7 +101,7 @@ async def _serve(self) -> None:
sys.exit(Arbiter.WORKER_BOOT_ERROR)

def run(self) -> None:
return asyncio.run(self._serve())
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())

async def callback_notify(self) -> None:
self.notify()
Expand Down
Loading