diff --git a/channels/routing.py b/channels/routing.py index f48c4d339..300538117 100644 --- a/channels/routing.py +++ b/channels/routing.py @@ -136,6 +136,46 @@ async def __call__(self, scope, receive, send): raise ValueError("No route found for path %r." % path) +class ValidURLRouter(URLRouter): + """ + URLRouter variant that returns 404 or closes WebSocket on invalid routes. + + Catches ValueError and Resolver404 from URL resolution. + + - For HTTP, responds with 404. + - For WebSocket, closes with code 1008 before handshake (resulting in 403). + - Other scope types propagate the exception. + """ + + async def __call__(self, scope, receive, send): + try: + return await super().__call__(scope, receive, send) + except (ValueError, Resolver404): + if scope["type"] == "http": + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [(b"content-type", b"text/plain")], + } + ) + await send( + { + "type": "http.response.body", + "body": b"404 Not Found", + } + ) + elif scope["type"] == "websocket": + await send( + { + "type": "websocket.close", + "code": 1008, + } + ) + else: + raise + + class ChannelNameRouter: """ Maps to different applications based on a "channel" key in the scope diff --git a/tests/test_testing.py b/tests/test_testing.py index fbfbf436f..5b6f33f90 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -6,7 +6,7 @@ from channels.consumer import AsyncConsumer from channels.generic.websocket import WebsocketConsumer -from channels.routing import URLRouter +from channels.routing import URLRouter, ValidURLRouter from channels.testing import HttpCommunicator, WebsocketCommunicator @@ -194,3 +194,44 @@ async def test_connection_scope(path): connected, _ = await communicator.connect() assert connected await communicator.disconnect() + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_route_validator_http(): + """ + Ensures ValidURLRouter returns 404 when route can't be matched. + """ + router = ValidURLRouter([path("test/", SimpleHttpApp())]) + communicator = HttpCommunicator(router, "GET", "/test/?foo=bar") + response = await communicator.get_response() + assert response["body"] == b"test response" + assert response["status"] == 200 + + communicator = HttpCommunicator(router, "GET", "/not-test/") + response = await communicator.get_response() + assert response["body"] == b"404 Not Found" + assert response["status"] == 404 + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_route_validator_websocket(): + """ + Ensures WebSocket connections are closed on unmatched routes. + + Forces ValidURLRouter to return 403 for unmatched routes during the handshake. + WebSocket clients will receive a 1008 close code. + + Ideally this should result in a 404, but that is not achievable in this context. + """ + router = ValidURLRouter([path("testws/", SimpleWebsocketApp())]) + communicator = WebsocketCommunicator(router, "/testws/") + connected, subprotocol = await communicator.connect() + assert connected + assert subprotocol is None + + communicator = WebsocketCommunicator(router, "/not-testws/") + connected, subprotocol = await communicator.connect() + assert connected is False + assert subprotocol == 1008