From 29b9c42273ad5e3a4a8120b3a1a8f7b578be894a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 26 Dec 2024 10:16:55 +0100 Subject: [PATCH 1/9] Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPMiddleware' children Co-authored-by: Thomas Grainger <413772+graingert@users.noreply.github.com> Co-authored-by: Nikita Gashkov <8746283+nikitagashkov@users.noreply.github.com> --- pyproject.toml | 2 -- starlette/middleware/base.py | 33 ++++++++++++++------------------- starlette/testclient.py | 19 ++++++++++--------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0712fe5da..b7baa1258 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,8 +86,6 @@ filterwarnings = [ "ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning", "ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning", "ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning", - # TODO: This warning appeared when we bumped anyio to 4.4.0. - "ignore: Unclosed .MemoryObject(Send|Receive)Stream.:ResourceWarning", ] [tool.coverage.run] diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f51b13f73..435f3e8c4 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -3,7 +3,7 @@ import typing import anyio -from anyio.abc import ObjectReceiveStream, ObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request @@ -107,9 +107,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: Exception | None = None - send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] - recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] - send_stream, recv_stream = anyio.create_memory_object_stream() async def receive_or_disconnect() -> Message: if response_sent.is_set(): @@ -130,10 +127,6 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: return message - async def close_recv_stream_on_response_sent() -> None: - await response_sent.wait() - recv_stream.close() - async def send_no_error(message: Message) -> None: try: await send_stream.send(message) @@ -144,13 +137,12 @@ async def send_no_error(message: Message) -> None: async def coro() -> None: nonlocal app_exc - async with send_stream: + with send_stream: try: await self.app(scope, receive_or_disconnect, send_no_error) except Exception as exc: app_exc = exc - task_group.start_soon(close_recv_stream_on_response_sent) task_group.start_soon(coro) try: @@ -166,14 +158,13 @@ async def coro() -> None: assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: - async with recv_stream: - async for message in recv_stream: - assert message["type"] == "http.response.body" - body = message.get("body", b"") - if body: - yield body - if not message.get("more_body", False): - break + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break if app_exc is not None: raise app_exc @@ -182,11 +173,15 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - with collapse_excgroups(): + send_stream: MemoryObjectSendStream[Message] + recv_stream: MemoryObjectReceiveStream[Message] + send_stream, recv_stream = anyio.create_memory_object_stream() + with recv_stream, send_stream, collapse_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() + recv_stream.close() async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: raise NotImplementedError() # pragma: no cover diff --git a/starlette/testclient.py b/starlette/testclient.py index 2c096aa22..94454e651 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -700,6 +700,8 @@ def reset_portal() -> None: receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send1, receive1 = anyio.create_memory_object_stream(math.inf) send2, receive2 = anyio.create_memory_object_stream(math.inf) + for channel in (send1, send2, receive1, receive2): + stack.callback(channel.close) self.stream_send = StapledObjectStream(send1, receive1) self.stream_receive = StapledObjectStream(send2, receive2) self.task = portal.start_task_soon(self.lifespan) @@ -747,12 +749,11 @@ async def receive() -> typing.Any: self.task.result() return message - async with self.stream_send, self.stream_receive: - await self.stream_receive.send({"type": "lifespan.shutdown"}) - message = await receive() - assert message["type"] in ( - "lifespan.shutdown.complete", - "lifespan.shutdown.failed", - ) - if message["type"] == "lifespan.shutdown.failed": - await receive() + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() From 2f317c06b90651d13e0059268dff91e06e8ac8ca Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:08:57 +0000 Subject: [PATCH 2/9] fix race condition in queue shutdown --- starlette/testclient.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 94454e651..0c66737e4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import enum import inspect import io import json @@ -85,6 +86,14 @@ class WebSocketDenialResponse( # type: ignore[misc] """ +class _Eof(enum.Enum): + EOF = enum.auto() + + +EOF: typing.Final = _Eof.EOF +Eof = typing.Literal[_Eof.EOF] + + class WebSocketTestSession: def __init__( self, @@ -97,7 +106,7 @@ def __init__( self.accepted_subprotocol = None self.portal_factory = portal_factory self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message | BaseException] = queue.Queue() + self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None def __enter__(self) -> WebSocketTestSession: @@ -129,8 +138,11 @@ def __exit__(self, *args: typing.Any) -> None: finally: self.portal.start_task_soon(self._notify_close) self.exit_stack.close() - while not self._send_queue.empty(): + + while True: message = self._send_queue.get() + if message is EOF: + break if isinstance(message, BaseException): raise message @@ -150,10 +162,13 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: finally: tg.cancel_scope.cancel() - async with anyio.create_task_group() as tg: - tg.start_soon(run_app, tg) - await self.should_close.wait() - tg.cancel_scope.cancel() + try: + async with anyio.create_task_group() as tg: + tg.start_soon(run_app, tg) + await self.should_close.wait() + tg.cancel_scope.cancel() + finally: + self._send_queue.put(EOF) async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): @@ -202,6 +217,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: def receive(self) -> Message: message = self._send_queue.get() + assert message is not EOF if isinstance(message, BaseException): raise message return message From 8015a843687fc39688322253040b95a6e0716bf6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:13:48 +0000 Subject: [PATCH 3/9] Update starlette/testclient.py --- starlette/testclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 0c66737e4..120aee02b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -168,7 +168,7 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: await self.should_close.wait() tg.cancel_scope.cancel() finally: - self._send_queue.put(EOF) + self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): From 60aa58d20e427b3169a6cb4f2aba9edea1a245d6 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 26 Dec 2024 11:21:57 +0100 Subject: [PATCH 4/9] Update starlette/middleware/base.py Co-authored-by: Thomas Grainger --- starlette/middleware/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 435f3e8c4..175ffc1f4 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -173,9 +173,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - send_stream: MemoryObjectSendStream[Message] - recv_stream: MemoryObjectReceiveStream[Message] - send_stream, recv_stream = anyio.create_memory_object_stream() + send_stream, recv_stream = anyio.create_memory_object_stream[Message]() with recv_stream, send_stream, collapse_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) From f3868f80bdc92ff83dce5f86c2224dcbf2d38316 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 26 Dec 2024 11:24:47 +0100 Subject: [PATCH 5/9] Update starlette/middleware/base.py Co-authored-by: Thomas Grainger --- starlette/middleware/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 175ffc1f4..6e37c6f60 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -3,7 +3,6 @@ import typing import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request From 28bc2e68b39df564fbff12cb30114c25f42978d2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:25:04 +0000 Subject: [PATCH 6/9] Update testclient.py --- starlette/testclient.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 120aee02b..0bed20dcd 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -710,12 +710,8 @@ def __enter__(self) -> TestClient: def reset_portal() -> None: self.portal = None - send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None] - receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None] - send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] - receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] - send1, receive1 = anyio.create_memory_object_stream(math.inf) - send2, receive2 = anyio.create_memory_object_stream(math.inf) + send1, receive1 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None](math.inf) + send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf) for channel in (send1, send2, receive1, receive2): stack.callback(channel.close) self.stream_send = StapledObjectStream(send1, receive1) From 66e1bf1e3db955d19caa1b04be74a06ebcab0eef Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:25:43 +0000 Subject: [PATCH 7/9] Update testclient.py --- starlette/testclient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 0bed20dcd..007da30f7 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -17,7 +17,6 @@ import anyio import anyio.abc import anyio.from_thread -from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable From e0566712da9d7f514653937c15baa933d7b487fe Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:30:03 +0000 Subject: [PATCH 8/9] weird mypy bug with | None --- starlette/testclient.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 007da30f7..718beba2f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -709,7 +709,9 @@ def __enter__(self) -> TestClient: def reset_portal() -> None: self.portal = None - send1, receive1 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None](math.inf) + send1, receive1 = anyio.create_memory_object_stream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ](math.inf) send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf) for channel in (send1, send2, receive1, receive2): stack.callback(channel.close) From bf9082badaa0cfb0b565f6107b351e62c3ea9a03 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 11:14:59 +0000 Subject: [PATCH 9/9] use typing.Union instead of Optional --- starlette/testclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index d8fc20b46..5811c6e39 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -691,7 +691,7 @@ def reset_portal() -> None: self.portal = None send1, receive1 = anyio.create_memory_object_stream[ - typing.Optional[typing.MutableMapping[str, typing.Any]] + typing.Union[typing.MutableMapping[str, typing.Any], None] ](math.inf) send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf) for channel in (send1, send2, receive1, receive2):