Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPMiddleware' children #2813

Merged
merged 11 commits into from
Dec 29, 2024
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ filterwarnings = [
"ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning",
"ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning",
"ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning",
# TODO: This warning appeared when we bumped anyio to 4.4.0.
"ignore: Unclosed .MemoryObject(Send|Receive)Stream.:ResourceWarning",
]

[tool.coverage.run]
Expand Down
30 changes: 11 additions & 19 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
Expand Down Expand Up @@ -107,9 +106,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

async def call_next(request: Request) -> Response:
app_exc: Exception | None = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
Expand All @@ -130,10 +126,6 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

Comment on lines -133 to -136
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
Expand All @@ -144,13 +136,12 @@ async def send_no_error(message: Message) -> None:
async def coro() -> None:
nonlocal app_exc

async with send_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand All @@ -166,14 +157,13 @@ async def coro() -> None:
assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break

if app_exc is not None:
raise app_exc
Expand All @@ -182,11 +172,13 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

with collapse_excgroups():
send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
Expand Down
56 changes: 34 additions & 22 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
Expand All @@ -16,7 +17,6 @@
import anyio
import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
Expand Down Expand Up @@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem this solves is not related to this PR then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes probably best to split into two PRs and add a changelog entry for the lost shutdown exception issue



class WebSocketTestSession:
def __init__(
self,
Expand All @@ -97,7 +105,7 @@ def __init__(
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
Expand Down Expand Up @@ -129,8 +137,11 @@ def __exit__(self, *args: typing.Any) -> None:
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():

while True:
message = self._send_queue.get()
if message is EOF:
break
if isinstance(message, BaseException):
raise message

Expand All @@ -150,10 +161,13 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None:
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
try:
async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
Expand Down Expand Up @@ -202,6 +216,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
Expand Down Expand Up @@ -694,12 +709,10 @@ def __enter__(self) -> TestClient:
def reset_portal() -> None:
self.portal = None

send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
send1, receive1 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None](math.inf)
send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf)
for channel in (send1, send2, receive1, receive2):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
self.task = portal.start_task_soon(self.lifespan)
Expand Down Expand Up @@ -747,12 +760,11 @@ async def receive() -> typing.Any:
self.task.result()
return message

async with self.stream_send, self.stream_receive:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
Loading