Skip to content

Commit e7c296b

Browse files
committed
Fix ASGI event handling for long-lived connections
After body events are consumed for authentication, the middleware's _fake_receive function now delegates to the original receive callable instead of returning None. This allows downstream applications to properly receive lifecycle events like http.disconnect, enabling proper cleanup for SSE connections, streaming responses, and other long-lived HTTP connections. Adds test to verify that _fake_receive correctly delegates to original receive after body events are exhausted.
1 parent 21579c9 commit e7c296b

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

mauth_client/middlewares/asgi.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def __call__(
6262
scope_copy[ENV_APP_UUID] = signed.app_uuid
6363
scope_copy[ENV_AUTHENTIC] = True
6464
scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version()
65-
await self.app(scope_copy, self._fake_receive(events), send)
65+
await self.app(scope_copy, self._fake_receive(events, receive), send)
6666
else:
6767
await self._send_response(send, status, message)
6868

@@ -100,18 +100,25 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) ->
100100
"body": json.dumps(body).encode("utf-8"),
101101
})
102102

103-
def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable:
103+
def _fake_receive(self, events: List[ASGIReceiveEvent], original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable:
104104
"""
105-
Create a fake, async receive function using an iterator of the events
106-
we've already read. This will be passed to downstream middlewares/apps
107-
instead of the usual receive fn, so that they can also "receive" the
108-
body events.
105+
Create a fake receive function that replays cached body events.
106+
107+
After the middleware consumes request body events for authentication,
108+
this allows downstream apps to also "receive" those events. Once all
109+
cached events are exhausted, delegates to the original receive to
110+
properly forward lifecycle events (like http.disconnect).
111+
112+
This is essential for long-lived connections (SSE, streaming responses)
113+
that need to detect client disconnects.
109114
"""
110115
events_iter = iter(events)
111116

112117
async def _receive() -> ASGIReceiveEvent:
113118
try:
114119
return next(events_iter)
115120
except StopIteration:
116-
pass
121+
# After body events are consumed, delegate to original receive
122+
# This allows proper handling of disconnects for SSE connections
123+
return await original_receive()
117124
return _receive

tests/middlewares/asgi_test.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import asyncio
12
import unittest
2-
from unittest.mock import patch
3-
43
from fastapi import FastAPI, Request
54
from fastapi.testclient import TestClient
65
from fastapi.websockets import WebSocket
6+
from unittest.mock import AsyncMock
7+
from unittest.mock import patch
78
from uuid import uuid4
89

910
from mauth_client.authenticator import LocalAuthenticator
@@ -220,3 +221,59 @@ def is_authentic_effect(self):
220221
self.client.get("/sub_app/path")
221222

222223
self.assertEqual(request_url, "/sub_app/path")
224+
225+
class TestMAuthASGIMiddlewareInSubApplication(unittest.IsolatedAsyncioTestCase):
226+
def setUp(self):
227+
self.app = FastAPI()
228+
Config.APP_UUID = str(uuid4())
229+
Config.MAUTH_URL = "https://mauth.com"
230+
Config.MAUTH_API_VERSION = "v1"
231+
Config.PRIVATE_KEY = "key"
232+
233+
@patch.object(LocalAuthenticator, "is_authentic")
234+
async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock):
235+
"""Test that after body events are consumed, _fake_receive delegates to original receive"""
236+
is_authentic_mock.return_value = (True, 200, "")
237+
238+
# Track that original receive was called after body events exhausted
239+
call_order = []
240+
241+
async def mock_app(scope, receive, send):
242+
# First receive should get body event
243+
event1 = await receive()
244+
call_order.append(("body", event1["type"]))
245+
246+
# Second receive should delegate to original receive
247+
event2 = await receive()
248+
call_order.append(("disconnect", event2["type"]))
249+
250+
await send({"type": "http.response.start", "status": 200, "headers": []})
251+
await send({"type": "http.response.body", "body": b""})
252+
253+
middleware = MAuthASGIMiddleware(mock_app)
254+
255+
# Mock receive that returns body then disconnect
256+
receive_calls = 0
257+
async def mock_receive():
258+
nonlocal receive_calls
259+
receive_calls += 1
260+
if receive_calls == 1:
261+
return {"type": "http.request", "body": b"test", "more_body": False}
262+
return {"type": "http.disconnect"}
263+
264+
send_mock = AsyncMock()
265+
scope = {
266+
"type": "http",
267+
"method": "POST",
268+
"path": "/test",
269+
"query_string": b"",
270+
"headers": []
271+
}
272+
273+
await middleware(scope, mock_receive, send_mock)
274+
275+
# Verify events were received in correct order
276+
self.assertEqual(len(call_order), 2)
277+
self.assertEqual(call_order[0], ("body", "http.request"))
278+
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
279+
self.assertEqual(receive_calls, 2) # Called once for auth, once from app

0 commit comments

Comments
 (0)