From ae86e1d91d64c52f0db583ea2444cdcfb5472a92 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 12:46:15 +0000 Subject: [PATCH 01/11] collapse exceptions groups from streaming response --- starlette/responses.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 31874f655..c522e7f23 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -18,6 +18,7 @@ import anyio import anyio.to_thread +from starlette._utils import collapse_excgroups from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders @@ -258,14 +259,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: except OSError: raise ClientDisconnect() else: - async with anyio.create_task_group() as task_group: + with collapse_excgroups(): + async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() From a4687d7c5fce21a164925629838b5bccf4f3acdd Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:08:08 +0000 Subject: [PATCH 02/11] collapse only one level of excg --- starlette/_utils.py | 29 +++++++++++++++++++++-------- tests/middleware/test_base.py | 17 +++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index 0c389dcb2..e93250163 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -13,12 +13,14 @@ else: # pragma: no cover from typing_extensions import TypeGuard -has_exceptiongroups = True if sys.version_info < (3, 11): # pragma: no cover try: from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] except ImportError: - has_exceptiongroups = False + + class BaseExceptionGroup(BaseException): # type: ignore[no-redef] + pass + T = typing.TypeVar("T") AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] @@ -74,12 +76,23 @@ async def __aexit__(self, *args: typing.Any) -> None | bool: def collapse_excgroups() -> typing.Generator[None, None, None]: try: yield - except BaseException as exc: - if has_exceptiongroups: # pragma: no cover - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc + except BaseExceptionGroup as excs: + if len(excs.exceptions) != 1: + raise + + exc = excs.exceptions[0] + context = exc.__context__ + tb = exc.__traceback__ + cause = exc.__cause__ + sc = exc.__suppress_context__ + try: + raise exc + finally: + exc.__traceback__ = tb + exc.__context__ = context + exc.__cause__ = cause + exc.__suppress_context__ = sc + del exc, cause, tb, context def get_route_path(scope: Scope) -> str: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7232cfd18..c2cecf48a 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -8,6 +8,7 @@ import anyio import pytest from anyio.abc import TaskStatus +from exceptiongroup import ExceptionGroup from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -41,6 +42,10 @@ def exc(request: Request) -> None: raise Exception("Exc") +def eg(request: Request) -> None: + raise ExceptionGroup("my exception group", [ValueError("TEST")]) + + def exc_stream(request: Request) -> StreamingResponse: return StreamingResponse(_generate_faulty_stream()) @@ -76,6 +81,7 @@ async def websocket_endpoint(session: WebSocket) -> None: routes=[ Route("/", endpoint=homepage), Route("/exc", endpoint=exc), + Route("/eg", endpoint=eg), Route("/exc-stream", endpoint=exc_stream), Route("/no-response", endpoint=NoResponse), WebSocketRoute("/ws", endpoint=websocket_endpoint), @@ -89,13 +95,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx1: response = client.get("/exc") - assert str(ctx.value) == "Exc" + assert str(ctx1.value) == "Exc" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx2: response = client.get("/exc-stream") - assert str(ctx.value) == "Faulty Stream" + assert str(ctx2.value) == "Faulty Stream" + + with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"): + client.get("/eg") with pytest.raises(RuntimeError): response = client.get("/no-response") From ffacdb260efe14eb32c271dcf2260ac7eaff2cf6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:14:29 +0000 Subject: [PATCH 03/11] conditional import --- tests/middleware/test_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index c2cecf48a..03bde0b0a 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextvars +import sys from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import AsyncExitStack from typing import Any @@ -8,7 +9,6 @@ import anyio import pytest from anyio.abc import TaskStatus -from exceptiongroup import ExceptionGroup from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -22,6 +22,9 @@ from starlette.websockets import WebSocket from tests.types import TestClientFactory +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( From 02c221f346ca2f007e843e48f85e854e24ba87be Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:23:47 +0000 Subject: [PATCH 04/11] avoid direct use of collapse_excgroups --- starlette/_utils.py | 11 ++++++++++- starlette/middleware/base.py | 6 +++--- starlette/middleware/wsgi.py | 3 ++- starlette/responses.py | 15 +++++++-------- tests/middleware/test_wsgi.py | 3 +-- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index e93250163..ca28c6fa5 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -4,7 +4,9 @@ import functools import sys import typing -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager + +import anyio.abc from starlette.types import Scope @@ -95,6 +97,13 @@ def collapse_excgroups() -> typing.Generator[None, None, None]: del exc, cause, tb, context +@asynccontextmanager +async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: + with collapse_excgroups(): + async with anyio.create_task_group() as tg: + yield tg + + def get_route_path(scope: Scope) -> str: path: str = scope["path"] root_path = scope.get("root_path", "") diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 6e37c6f60..77404f790 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ import anyio -from starlette._utils import collapse_excgroups +from starlette._utils import create_collapsing_task_group from starlette.requests import ClientDisconnect, Request from starlette.responses import AsyncContentStream, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -173,8 +173,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: return response send_stream, recv_stream = anyio.create_memory_object_stream[Message]() - with recv_stream, send_stream, collapse_excgroups(): - async with anyio.create_task_group() as task_group: + with recv_stream, send_stream: + async with create_collapsing_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 6e0a3fae6..3e9ad0296 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -9,6 +9,7 @@ import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from starlette._utils import create_collapsing_task_group from starlette.types import Receive, Scope, Send warnings.warn( @@ -102,7 +103,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) diff --git a/starlette/responses.py b/starlette/responses.py index c522e7f23..1f8b87bea 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -18,7 +18,7 @@ import anyio import anyio.to_thread -from starlette._utils import collapse_excgroups +from starlette._utils import create_collapsing_task_group from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders @@ -259,15 +259,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: except OSError: raise ClientDisconnect() else: - with collapse_excgroups(): - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 3511c89c9..418f0946f 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,7 +4,6 @@ import pytest -from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ from tests.types import TestClientFactory @@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None: # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError), collapse_excgroups(): + with pytest.raises(RuntimeError): client.get("/") From c7c6206d10e64ae4947cafd9890e3b00929389c8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:24:42 +0000 Subject: [PATCH 05/11] make private --- starlette/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index ca28c6fa5..2d6c70242 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -75,7 +75,7 @@ async def __aexit__(self, *args: typing.Any) -> None | bool: @contextmanager -def collapse_excgroups() -> typing.Generator[None, None, None]: +def _collapse_excgroups() -> typing.Generator[None, None, None]: try: yield except BaseExceptionGroup as excs: @@ -99,7 +99,7 @@ def collapse_excgroups() -> typing.Generator[None, None, None]: @asynccontextmanager async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: - with collapse_excgroups(): + with _collapse_excgroups(): async with anyio.create_task_group() as tg: yield tg From 802409afe3e038856aa8e87543ecc902c5400231 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:43:44 +0000 Subject: [PATCH 06/11] remove incorrect mypy pin --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 95f195c50..223b6319f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,6 @@ combine-as-imports = true [tool.mypy] strict = true -python_version = "3.9" [[tool.mypy.overrides]] module = "starlette.testclient.*" From 5b43511fd25391e941d2e6a88dab0d94067017fb Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 15:00:01 +0000 Subject: [PATCH 07/11] Update tests/middleware/test_base.py --- tests/middleware/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 03bde0b0a..483dd869e 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -22,7 +22,7 @@ from starlette.websockets import WebSocket from tests.types import TestClientFactory -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import ExceptionGroup From 6500d2f8a8ab0c00c29dd608751100305a0686d2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 14:45:14 +0000 Subject: [PATCH 08/11] remove unused ignore --- starlette/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index 2d6c70242..daa17d8c4 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -17,7 +17,7 @@ if sys.version_info < (3, 11): # pragma: no cover try: - from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] + from exceptiongroup import BaseExceptionGroup except ImportError: class BaseExceptionGroup(BaseException): # type: ignore[no-redef] From 5e4a262623a085bc506f2ca280adb4b09a55c879 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 15:31:41 +0000 Subject: [PATCH 09/11] add coverage for collapsing --- tests/test__utils.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/test__utils.py b/tests/test__utils.py index 916f460d4..e4652ff1b 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -1,11 +1,15 @@ import functools +import sys from typing import Any import pytest -from starlette._utils import get_route_path, is_async_callable +from starlette._utils import create_collapsing_task_group, get_route_path, is_async_callable from starlette.types import Scope +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroups import ExceptionGroup + def test_async_func() -> None: async def async_func() -> None: ... # pragma: no cover @@ -94,3 +98,31 @@ async def async_func( ) def test_get_route_path(scope: Scope, expected_result: str) -> None: assert get_route_path(scope) == expected_result + + +@pytest.mark.anyio +async def test_collapsing_task_group_one_exc() -> None: + class MyException(Exception): + pass + + with pytest.raises(MyException): + async with create_collapsing_task_group(): + raise MyException + + +@pytest.mark.anyio +async def test_collapsing_task_group_two_exc() -> None: + class MyException(Exception): + pass + + async def raise_exc() -> None: + raise MyException + + with pytest.raises(ExceptionGroup) as exc: + async with create_collapsing_task_group() as task_group: + task_group.start_soon(raise_exc) + raise MyException + + exc1, exc2 = exc.value.exceptions + assert isinstance(exc1, MyException) + assert isinstance(exc2, MyException) From e44b25429e83d7817c9e63ce53fa770ef8be72c2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 15:33:30 +0000 Subject: [PATCH 10/11] optimize --- starlette/_utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/starlette/_utils.py b/starlette/_utils.py index daa17d8c4..143c660f3 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -4,7 +4,7 @@ import functools import sys import typing -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager import anyio.abc @@ -74,10 +74,11 @@ async def __aexit__(self, *args: typing.Any) -> None | bool: return None -@contextmanager -def _collapse_excgroups() -> typing.Generator[None, None, None]: +@asynccontextmanager +async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: try: - yield + async with anyio.create_task_group() as tg: + yield tg except BaseExceptionGroup as excs: if len(excs.exceptions) != 1: raise @@ -97,13 +98,6 @@ def _collapse_excgroups() -> typing.Generator[None, None, None]: del exc, cause, tb, context -@asynccontextmanager -async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: - with _collapse_excgroups(): - async with anyio.create_task_group() as tg: - yield tg - - def get_route_path(scope: Scope) -> str: path: str = scope["path"] root_path = scope.get("root_path", "") From e24198f2832ea1edb0b4ed6e6708aa6f9f4bd463 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 15:47:00 +0000 Subject: [PATCH 11/11] Update tests/test__utils.py --- tests/test__utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test__utils.py b/tests/test__utils.py index e4652ff1b..60af81392 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -8,7 +8,7 @@ from starlette.types import Scope if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroups import ExceptionGroup + from exceptiongroup import ExceptionGroup def test_async_func() -> None: