Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions docs/lifespan.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,58 @@ app = Starlette(
The `state` received on the requests is a **shallow** copy of the state received on the
lifespan handler.

## Accessing State

The state can be accessed using either attribute-style or dictionary-style syntax.

The dictionary-style syntax was introduced in Starlette 0.52.0 (January 2026), with the idea of
improving type safety when using the lifespan state, given that `Request` became a generic over
the state type.

```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TypedDict

import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route


class State(TypedDict):
http_client: httpx.AsyncClient


@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[State]:
async with httpx.AsyncClient() as client:
yield {"http_client": client}


async def homepage(request: Request[State]) -> PlainTextResponse:
client = request.state["http_client"]

reveal_type(client) # Revealed type is 'httpx.AsyncClient'

response = await client.get("https://www.example.com")
return PlainTextResponse(response.text)

app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)])
```

!!! note
There were many attempts to make this work with attribute-style access instead of
dictionary-style access, but none were satisfactory, given they would have been
breaking changes, or there were typing limitations.

For more details, see:

- [@Kludex/starlette#issues/3005](https://github.com/Kludex/starlette/issues/3005)
- [@python/typing#discussions/1457](https://github.com/python/typing/discussions/1457)
- [@Kludex/starlette#pull/3036](https://github.com/Kludex/starlette/pull/3036)

## Running lifespan in tests

You should use `TestClient` as a context manager, to ensure that the lifespan is called.
Expand Down
15 changes: 15 additions & 0 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,3 +689,18 @@ def __getattr__(self, key: Any) -> Any:

def __delattr__(self, key: Any) -> None:
del self._state[key]

def __getitem__(self, key: str) -> Any:
return self._state[key]

def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value

def __delitem__(self, key: str) -> None:
del self._state[key]

def __iter__(self) -> Iterator[str]:
return iter(self._state)

def __len__(self) -> int:
return len(self._state)
14 changes: 9 additions & 5 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import json
from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
from typing import TYPE_CHECKING, Any, NoReturn, cast
from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast

import anyio
from typing_extensions import TypeVar

from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
Expand Down Expand Up @@ -68,7 +69,10 @@ class ClientDisconnect(Exception):
pass


class HTTPConnection(Mapping[str, Any]):
StateT = TypeVar("StateT", bound=Mapping[str, Any] | State, default=State)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this will pass the TypedDict check in some static type checkers.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which one you don't think it will not? We test with mypy here, but pyright doesn't seem to be failing (I use it on my environment).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great to be able to pass the inspection.



class HTTPConnection(Mapping[str, Any], Generic[StateT]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
Expand Down Expand Up @@ -172,14 +176,14 @@ def user(self) -> Any:
return self.scope["user"]

@property
def state(self) -> State:
def state(self) -> StateT:
if not hasattr(self, "_state"):
# Ensure 'state' has an empty dict if it's not already populated.
self.scope.setdefault("state", {})
# Create a state instance with a reference to the dict in which it should
# store info
self._state = State(self.scope["state"])
return self._state
return cast(StateT, self._state)

def url_for(self, name: str, /, **path_params: Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
Expand All @@ -197,7 +201,7 @@ async def empty_send(message: Message) -> NoReturn:
raise RuntimeError("Send channel has not been made available")


class Request(HTTPConnection):
class Request(HTTPConnection[StateT]):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are missing the same in WebSockets.

_form: FormData | None

def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TypedDict

import anyio.from_thread
import pytest
Expand Down Expand Up @@ -95,6 +96,19 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


class CustomState(TypedDict):
count: int


@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[CustomState]:
yield {"count": 1}


async def state_count(request: Request[CustomState]) -> JSONResponse:
return JSONResponse({"count": request.state["count"]}, status_code=200)


users = Router(
routes=[
Route("/", endpoint=all_users_page),
Expand Down Expand Up @@ -122,6 +136,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
Route("/func", endpoint=func_homepage),
Route("/async", endpoint=async_homepage),
Route("/class", endpoint=Homepage),
Route("/state", endpoint=state_count),
Route("/500", endpoint=runtime_error),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
Expand All @@ -132,6 +147,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
],
exception_handlers=exception_handlers, # type: ignore
middleware=middleware,
lifespan=lifespan,
)


Expand Down Expand Up @@ -216,6 +232,12 @@ def test_500(test_client_factory: TestClientFactory) -> None:
assert response.json() == {"detail": "Server Error"}


def test_request_state(client: TestClient) -> None:
response = client.get("/state")
assert response.status_code == 200
assert response.json() == {"count": 1}


def test_websocket_raise_websocket_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
Expand Down Expand Up @@ -256,6 +278,7 @@ def test_routes() -> None:
Route("/func", endpoint=func_homepage, methods=["GET"]),
Route("/async", endpoint=async_homepage, methods=["GET"]),
Route("/class", endpoint=Homepage),
Route("/state", endpoint=state_count, methods=["GET"]),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
Expand Down
22 changes: 22 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,28 @@ def test_request_state_object() -> None:
with pytest.raises(AttributeError):
s.new

# Test dictionary-style methods
# Test __setitem__
s["dict_key"] = "dict_value"
assert s["dict_key"] == "dict_value"
assert s.dict_key == "dict_value"

# Test __iter__
s["another_key"] = "another_value"
keys = list(s)
assert "old" in keys
assert "dict_key" in keys
assert "another_key" in keys

# Test __len__
assert len(s) == 3

# Test __delitem__
del s["dict_key"]
assert len(s) == 2
with pytest.raises(KeyError):
s["dict_key"]


def test_request_state(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
Expand Down