Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ Starlette only requires `anyio`, and the following are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of these with `pip install starlette[full]`.
Expand Down Expand Up @@ -135,7 +134,6 @@ in isolation.
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
2 changes: 0 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ Starlette only requires `anyio`, and the following dependencies are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of these with `pip install starlette[full]`.
Expand Down Expand Up @@ -156,7 +155,6 @@ in isolation.
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ dependencies = [

[project.optional-dependencies]
full = [
"itsdangerous",
"jinja2",
"python-multipart>=0.0.18",
"pyyaml",
Expand Down
57 changes: 32 additions & 25 deletions starlette/middleware/sessions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

import json
from base64 import b64decode, b64encode
import time
from typing import Literal

import itsdangerous
from itsdangerous.exc import BadSignature

from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.signing import TimestampSigner
from starlette.types import ASGIApp, Message, Receive, Scope, Send


Expand All @@ -25,7 +23,7 @@ def __init__(
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
self.signer = TimestampSigner(secret_key)
self.session_cookie = session_cookie
self.max_age = max_age
self.path = path
Expand All @@ -41,35 +39,44 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return

connection = HTTPConnection(scope)
initial_session_was_empty = True
initial_payload: bytes | None = None
initial_timestamp: int | None = None

if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
initial_session_was_empty = False
except BadSignature:
scope["session"] = {}
else:
result = self.signer.unsign(data, max_age=self.max_age)
if result is not None:
initial_payload, initial_timestamp = result
scope["session"] = json.loads(initial_payload)

if initial_payload is None:
scope["session"] = {}

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
if scope["session"]:
# We have session data to persist.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif not initial_session_was_empty:
current_payload = json.dumps(scope["session"]).encode("utf-8")
needs_cookie = False

if current_payload != initial_payload:
needs_cookie = True
elif self.max_age is not None and initial_timestamp is not None:
if time.time() - initial_timestamp > self.max_age / 4:
needs_cookie = True

if needs_cookie:
data = self.signer.sign(current_payload)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age is not None else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif initial_payload is not None:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
Expand Down
103 changes: 103 additions & 0 deletions starlette/signing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import time

from starlette.datastructures import Secret


class TimestampSigner:
"""
Signs and unsigns data with HMAC-SHA256 and timestamp validation.

Format: <b64(payload+timestamp)><b64(signature)>_
"""

def __init__(self, secret: str | Secret) -> None:
"""
Initialize signer with a secret key.

Args:
secret: Secret key for HMAC signing
"""
self._secret = str(secret).encode("utf-8")

def sign(self, data: bytes) -> bytes:
"""
Sign data with current timestamp.

Args:
data: Raw data to sign

Returns:
Signed token as bytes
"""
timestamp_bytes = int(time.time()).to_bytes(5, "big")

combined = data + timestamp_bytes
combined_encoded = _b64_encode(combined)

signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16]
signature_encoded = _b64_encode(signature)

return combined_encoded + (signature_encoded + b"_")

def unsign(self, signed_data: bytes, max_age: int | None = None) -> tuple[bytes, int] | None:
"""
Verify and extract data from signed token.

Args:
signed_data: Signed token
max_age: Maximum age in seconds (optional)

Returns:
Tuple of (payload bytes, timestamp) on success, None on failure
"""
# Quick pre-checks
if len(signed_data) < 30 or signed_data[-1] != 95:
return None

signature_encoded = signed_data[-23:-1]
signature = _b64_decode(signature_encoded)
if signature is None or len(signature) != 16:
return None

combined_encoded = signed_data[:-23]
expected_signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16]
if not hmac.compare_digest(signature, expected_signature):
return None

combined = _b64_decode(combined_encoded)
if combined is None: # pragma: no cover
return None

timestamp_bytes = combined[-5:]
timestamp = int.from_bytes(timestamp_bytes, "big")

# Check timestamp age if max_age is set
if max_age is not None:
if time.time() - timestamp > max_age:
return None

data = combined[:-5]
return (data, timestamp)


def _b64_encode(data: bytes) -> bytes:
"""Encode bytes to base64url format without padding."""
return base64.urlsafe_b64encode(data).rstrip(b"=")


def _b64_decode(data: bytes) -> bytes | None:
"""Decode base64url format, adding padding if needed. Returns None on error."""
# Add padding if needed
padding = 4 - (len(data) % 4)
if padding != 4:
data = data + b"=" * padding

try:
return base64.urlsafe_b64decode(data)
except Exception:
return None
85 changes: 85 additions & 0 deletions tests/middleware/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from unittest import mock

from starlette.applications import Starlette
from starlette.middleware import Middleware
Expand Down Expand Up @@ -198,3 +199,87 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None:
client.cookies.delete("session")
response = client.get("/view_session")
assert response.json() == {"session": {}}


def test_session_no_cookie_when_unchanged(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)

response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" in response.headers

response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" not in response.headers


def test_session_cookie_when_modified(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)

response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" in response.headers

response = client.post("/update_session", json={"some": "data", "more": "values"})
assert response.json() == {"session": {"some": "data", "more": "values"}}
assert "set-cookie" in response.headers


def test_session_cookie_refresh_when_stale(test_client_factory: TestClientFactory) -> None:
max_age = 60
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=max_age)],
)
client = test_client_factory(app)

with mock.patch("time.time", return_value=1000):
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" in response.headers

with mock.patch("time.time", return_value=1000):
response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" not in response.headers

with mock.patch("time.time", return_value=1000 + max_age / 4 + 1):
response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" in response.headers


def test_session_no_cookie_when_max_age_none(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)],
)
client = test_client_factory(app)

response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" in response.headers

response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" not in response.headers
Loading