-
-
Notifications
You must be signed in to change notification settings - Fork 952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPMiddleware' children #2813
Changes from 7 commits
29b9c42
2f317c0
8015a84
60aa58d
f3868f8
28bc2e6
66e1bf1
e056671
de4b745
bf9082b
2493818
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import enum | ||
import inspect | ||
import io | ||
import json | ||
|
@@ -16,7 +17,6 @@ | |
import anyio | ||
import anyio.abc | ||
import anyio.from_thread | ||
from anyio.abc import ObjectReceiveStream, ObjectSendStream | ||
from anyio.streams.stapled import StapledObjectStream | ||
|
||
from starlette._utils import is_async_callable | ||
|
@@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc] | |
""" | ||
|
||
|
||
class _Eof(enum.Enum): | ||
EOF = enum.auto() | ||
|
||
|
||
EOF: typing.Final = _Eof.EOF | ||
Eof = typing.Literal[_Eof.EOF] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem this solves is not related to this PR then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes probably best to split into two PRs and add a changelog entry for the lost shutdown exception issue |
||
|
||
|
||
class WebSocketTestSession: | ||
def __init__( | ||
self, | ||
|
@@ -97,7 +105,7 @@ def __init__( | |
self.accepted_subprotocol = None | ||
self.portal_factory = portal_factory | ||
self._receive_queue: queue.Queue[Message] = queue.Queue() | ||
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue() | ||
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() | ||
self.extra_headers = None | ||
|
||
def __enter__(self) -> WebSocketTestSession: | ||
|
@@ -129,8 +137,11 @@ def __exit__(self, *args: typing.Any) -> None: | |
finally: | ||
self.portal.start_task_soon(self._notify_close) | ||
self.exit_stack.close() | ||
while not self._send_queue.empty(): | ||
|
||
while True: | ||
message = self._send_queue.get() | ||
if message is EOF: | ||
break | ||
if isinstance(message, BaseException): | ||
raise message | ||
|
||
|
@@ -150,10 +161,13 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: | |
finally: | ||
tg.cancel_scope.cancel() | ||
|
||
async with anyio.create_task_group() as tg: | ||
tg.start_soon(run_app, tg) | ||
await self.should_close.wait() | ||
tg.cancel_scope.cancel() | ||
try: | ||
async with anyio.create_task_group() as tg: | ||
tg.start_soon(run_app, tg) | ||
await self.should_close.wait() | ||
tg.cancel_scope.cancel() | ||
finally: | ||
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ | ||
|
||
async def _asgi_receive(self) -> Message: | ||
while self._receive_queue.empty(): | ||
|
@@ -202,6 +216,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: | |
|
||
def receive(self) -> Message: | ||
message = self._send_queue.get() | ||
assert message is not EOF | ||
if isinstance(message, BaseException): | ||
raise message | ||
return message | ||
|
@@ -694,12 +709,10 @@ def __enter__(self) -> TestClient: | |
def reset_portal() -> None: | ||
self.portal = None | ||
|
||
send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None] | ||
receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None] | ||
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] | ||
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] | ||
send1, receive1 = anyio.create_memory_object_stream(math.inf) | ||
send2, receive2 = anyio.create_memory_object_stream(math.inf) | ||
send1, receive1 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None](math.inf) | ||
send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf) | ||
for channel in (send1, send2, receive1, receive2): | ||
stack.callback(channel.close) | ||
self.stream_send = StapledObjectStream(send1, receive1) | ||
self.stream_receive = StapledObjectStream(send2, receive2) | ||
self.task = portal.start_task_soon(self.lifespan) | ||
|
@@ -747,12 +760,11 @@ async def receive() -> typing.Any: | |
self.task.result() | ||
return message | ||
|
||
async with self.stream_send, self.stream_receive: | ||
await self.stream_receive.send({"type": "lifespan.shutdown"}) | ||
message = await receive() | ||
assert message["type"] in ( | ||
"lifespan.shutdown.complete", | ||
"lifespan.shutdown.failed", | ||
) | ||
if message["type"] == "lifespan.shutdown.failed": | ||
await receive() | ||
await self.stream_receive.send({"type": "lifespan.shutdown"}) | ||
message = await receive() | ||
assert message["type"] in ( | ||
"lifespan.shutdown.complete", | ||
"lifespan.shutdown.failed", | ||
) | ||
if message["type"] == "lifespan.shutdown.failed": | ||
await receive() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this as per @graingert comment on https://matrix.to/#/!JfFIjeKHlqEVmAsxYP:gitter.im/$S0X8-1wH1qscoq6FQjVCVDgBL4Qk-5BcTnaQfqBCkGI?via=gitter.im&via=matrix.org&via=matrix.freyachat.eu