diff --git a/docs/lifespan.md b/docs/lifespan.md index a5a766251..6411049b6 100644 --- a/docs/lifespan.md +++ b/docs/lifespan.md @@ -73,6 +73,58 @@ app = Starlette( The `state` received on the requests is a **shallow** copy of the state received on the lifespan handler. +## Accessing State + +The state can be accessed using either attribute-style or dictionary-style syntax. + +The dictionary-style syntax was introduced in Starlette 0.52.0 (January 2026), with the idea of +improving type safety when using the lifespan state, given that `Request` became a generic over +the state type. + +```python +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import TypedDict + +import httpx +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route + + +class State(TypedDict): + http_client: httpx.AsyncClient + + +@asynccontextmanager +async def lifespan(app: Starlette) -> AsyncIterator[State]: + async with httpx.AsyncClient() as client: + yield {"http_client": client} + + +async def homepage(request: Request[State]) -> PlainTextResponse: + client = request.state["http_client"] + + reveal_type(client) # Revealed type is 'httpx.AsyncClient' + + response = await client.get("https://www.example.com") + return PlainTextResponse(response.text) + +app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)]) +``` + +!!! note + There were many attempts to make this work with attribute-style access instead of + dictionary-style access, but none were satisfactory, given they would have been + breaking changes, or there were typing limitations. + + For more details, see: + + - [@Kludex/starlette#issues/3005](https://github.com/Kludex/starlette/issues/3005) + - [@python/typing#discussions/1457](https://github.com/python/typing/discussions/1457) + - [@Kludex/starlette#pull/3036](https://github.com/Kludex/starlette/pull/3036) + ## Running lifespan in tests You should use `TestClient` as a context manager, to ensure that the lifespan is called. diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e621900ae..615874c9c 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -689,3 +689,18 @@ def __getattr__(self, key: Any) -> Any: def __delattr__(self, key: Any) -> None: del self._state[key] + + def __getitem__(self, key: str) -> Any: + return self._state[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._state[key] = value + + def __delitem__(self, key: str) -> None: + del self._state[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._state) + + def __len__(self) -> int: + return len(self._state) diff --git a/starlette/requests.py b/starlette/requests.py index 99e8b9e19..041ef25a2 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -3,9 +3,10 @@ import json from collections.abc import AsyncGenerator, Iterator, Mapping from http import cookies as http_cookies -from typing import TYPE_CHECKING, Any, NoReturn, cast +from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast import anyio +from typing_extensions import TypeVar from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State @@ -68,7 +69,10 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(Mapping[str, Any]): +StateT = TypeVar("StateT", bound=Mapping[str, Any] | State, default=State) + + +class HTTPConnection(Mapping[str, Any], Generic[StateT]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. @@ -172,14 +176,14 @@ def user(self) -> Any: return self.scope["user"] @property - def state(self) -> State: + def state(self) -> StateT: if not hasattr(self, "_state"): # Ensure 'state' has an empty dict if it's not already populated. self.scope.setdefault("state", {}) # Create a state instance with a reference to the dict in which it should # store info self._state = State(self.scope["state"]) - return self._state + return cast(StateT, self._state) def url_for(self, name: str, /, **path_params: Any) -> URL: url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app") @@ -197,7 +201,7 @@ async def empty_send(message: Message) -> NoReturn: raise RuntimeError("Send channel has not been made available") -class Request(HTTPConnection): +class Request(HTTPConnection[StateT]): _form: FormData | None def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send): diff --git a/tests/test_applications.py b/tests/test_applications.py index 20ff06385..30e1d7504 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator from contextlib import asynccontextmanager from pathlib import Path +from typing import TypedDict import anyio.from_thread import pytest @@ -95,6 +96,19 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER) +class CustomState(TypedDict): + count: int + + +@asynccontextmanager +async def lifespan(app: Starlette) -> AsyncGenerator[CustomState]: + yield {"count": 1} + + +async def state_count(request: Request[CustomState]) -> JSONResponse: + return JSONResponse({"count": request.state["count"]}, status_code=200) + + users = Router( routes=[ Route("/", endpoint=all_users_page), @@ -122,6 +136,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> Route("/func", endpoint=func_homepage), Route("/async", endpoint=async_homepage), Route("/class", endpoint=Homepage), + Route("/state", endpoint=state_count), Route("/500", endpoint=runtime_error), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), @@ -132,6 +147,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> ], exception_handlers=exception_handlers, # type: ignore middleware=middleware, + lifespan=lifespan, ) @@ -216,6 +232,12 @@ def test_500(test_client_factory: TestClientFactory) -> None: assert response.json() == {"detail": "Server Error"} +def test_request_state(client: TestClient) -> None: + response = client.get("/state") + assert response.status_code == 200 + assert response.json() == {"count": 1} + + def test_websocket_raise_websocket_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive() @@ -256,6 +278,7 @@ def test_routes() -> None: Route("/func", endpoint=func_homepage, methods=["GET"]), Route("/async", endpoint=async_homepage, methods=["GET"]), Route("/class", endpoint=Homepage), + Route("/state", endpoint=state_count, methods=["GET"]), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), diff --git a/tests/test_requests.py b/tests/test_requests.py index 1799e89e0..0a28a00e4 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -302,6 +302,28 @@ def test_request_state_object() -> None: with pytest.raises(AttributeError): s.new + # Test dictionary-style methods + # Test __setitem__ + s["dict_key"] = "dict_value" + assert s["dict_key"] == "dict_value" + assert s.dict_key == "dict_value" + + # Test __iter__ + s["another_key"] = "another_value" + keys = list(s) + assert "old" in keys + assert "dict_key" in keys + assert "another_key" in keys + + # Test __len__ + assert len(s) == 3 + + # Test __delitem__ + del s["dict_key"] + assert len(s) == 2 + with pytest.raises(KeyError): + s["dict_key"] + def test_request_state(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: