diff --git a/starlette/websockets.py b/starlette/websockets.py index fb76361c8..98396f30e 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -54,7 +54,9 @@ async def receive(self) -> Message: self.client_state = WebSocketState.DISCONNECTED return message else: - raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') + raise WebSocketDisconnect( + code=1006, reason='Cannot call "receive" once a disconnect message has been received.' + ) async def send(self, message: Message) -> None: """ @@ -95,7 +97,7 @@ async def send(self, message: Message) -> None: self.application_state = WebSocketState.DISCONNECTED await self._send(message) else: - raise RuntimeError('Cannot call "send" once a close message has been sent.') + raise WebSocketDisconnect(code=1006, reason='Cannot call "send" once a close message has been sent.') async def accept( self, diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e76d8f29b..c6c32028d 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -449,9 +449,10 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.close() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): pass # pragma: no cover + assert exc.value.code == 1006 def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None: @@ -463,9 +464,10 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: message = await websocket.receive() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/") as websocket: websocket.close() + assert exc.value.code == 1006 def test_websocket_scope_interface() -> None: