Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ Starlette only requires `anyio`, and the following are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of these with `pip install starlette[full]`.
Expand Down Expand Up @@ -135,7 +134,6 @@ in isolation.
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
2 changes: 0 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ Starlette only requires `anyio`, and the following dependencies are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of these with `pip install starlette[full]`.
Expand Down Expand Up @@ -156,7 +155,6 @@ in isolation.
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ dependencies = [

[project.optional-dependencies]
full = [
"itsdangerous",
"jinja2",
"python-multipart>=0.0.18",
"pyyaml",
Expand Down
17 changes: 7 additions & 10 deletions starlette/middleware/sessions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import annotations

import json
from base64 import b64decode, b64encode
from typing import Literal

import itsdangerous
from itsdangerous.exc import BadSignature

from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.signing import TimestampSigner
from starlette.types import ASGIApp, Message, Receive, Scope, Send


Expand All @@ -25,7 +22,7 @@ def __init__(
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
self.signer = TimestampSigner(secret_key)
self.session_cookie = session_cookie
self.max_age = max_age
self.path = path
Expand All @@ -45,11 +42,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
payload = self.signer.unsign(data, max_age=self.max_age)
if payload is not None:
scope["session"] = json.loads(payload)
initial_session_was_empty = False
except BadSignature:
else:
scope["session"] = {}
else:
scope["session"] = {}
Expand All @@ -58,7 +55,7 @@ async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
if scope["session"]:
# We have session data to persist.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = json.dumps(scope["session"]).encode("utf-8")
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
Expand Down
102 changes: 102 additions & 0 deletions starlette/signing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import time

from starlette.datastructures import Secret


class TimestampSigner:
"""
Signs and unsigns data with HMAC-SHA256 and timestamp validation.

Format: <b64(payload+timestamp)><b64(signature)>_
"""

def __init__(self, secret: str | Secret) -> None:
"""
Initialize signer with a secret key.

Args:
secret: Secret key for HMAC signing
"""
self._secret = str(secret).encode("utf-8")

def sign(self, data: bytes) -> bytes:
"""
Sign data with current timestamp.

Args:
data: Raw data to sign

Returns:
Signed token as bytes
"""
timestamp_bytes = int(time.time()).to_bytes(5, "big")

combined = data + timestamp_bytes
combined_encoded = _b64_encode(combined)

signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16]
signature_encoded = _b64_encode(signature)

return combined_encoded + (signature_encoded + b"_")

def unsign(self, signed_data: bytes, max_age: int | None = None) -> bytes | None:
"""
Verify and extract data from signed token.

Args:
signed_data: Signed token
max_age: Maximum age in seconds (optional)

Returns:
Payload bytes on success, None on failure
"""
# Quick pre-checks
if len(signed_data) < 30 or signed_data[-1] != 95:
return None

signature_encoded = signed_data[-23:-1]
signature = _b64_decode(signature_encoded)
if signature is None or len(signature) != 16:
return None

combined_encoded = signed_data[:-23]
expected_signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16]
if not hmac.compare_digest(signature, expected_signature):
return None

combined = _b64_decode(combined_encoded)
if combined is None: # pragma: no cover
return None

# Check timestamp age if max_age is set
if max_age is not None:
timestamp_bytes = combined[-5:]
timestamp = int.from_bytes(timestamp_bytes, "big")
if time.time() - timestamp > max_age:
return None

data = combined[:-5]
return data


def _b64_encode(data: bytes) -> bytes:
"""Encode bytes to base64url format without padding."""
return base64.urlsafe_b64encode(data).rstrip(b"=")


def _b64_decode(data: bytes) -> bytes | None:
"""Decode base64url format, adding padding if needed. Returns None on error."""
# Add padding if needed
padding = 4 - (len(data) % 4)
if padding != 4:
data = data + b"=" * padding

try:
return base64.urlsafe_b64decode(data)
except Exception:
return None
156 changes: 156 additions & 0 deletions tests/test_signing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from unittest import mock

from starlette.signing import TimestampSigner, _b64_decode, _b64_encode


class TestTimestampSigner:
def test_sign_basic(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"hello")
assert isinstance(signed, bytes)
assert len(signed) > 30

def test_round_trip(self) -> None:
signer = TimestampSigner("secret")
test_cases = [
b"",
b"a",
b"hello world",
b"\x00\x01\x02\xff\xfe\xfd",
b"\x00\x00\x00\x00",
b"x" * 1000,
b'{"user": "alice", "id": 123}',
"Hello 世界 🌍".encode(),
]
for data in test_cases:
signed = signer.sign(data)
unsigned = signer.unsign(signed)
assert unsigned == data

def test_output_is_ascii(self) -> None:
signer = TimestampSigner("secret")
data = "Hello 世界 🌍".encode()
signed = signer.sign(data)
signed.decode("ascii")
assert all(b < 128 for b in signed)

def test_no_padding_in_output(self) -> None:
signer = TimestampSigner("secret")
for size in range(10):
data = b"x" * size
signed = signer.sign(data)
assert b"=" not in signed

def test_no_forbidden_characters(self) -> None:
signer = TimestampSigner("secret")
forbidden = set(b' ,;"\\')
test_data = [
b"simple",
b'{"user": "alice", "roles": ["admin"]}',
b"\x00\xff" * 10,
b"x" * 10,
]
for data in test_data:
signed = signer.sign(data)
assert not any(c in forbidden for c in signed)
valid_chars = set(b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
assert all(c in valid_chars for c in signed)


class TestTimestampValidation:
def test_unsign_with_valid_max_age(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")
assert signer.unsign(signed, max_age=10) == b"data"
assert signer.unsign(signed, max_age=1000) == b"data"

def test_unsign_with_expired_max_age(self) -> None:
signer = TimestampSigner("secret")
with mock.patch("time.time", return_value=1000):
signed = signer.sign(b"data")

with mock.patch("time.time", return_value=1003):
assert signer.unsign(signed, max_age=1) is None
assert signer.unsign(signed, max_age=0) is None

assert signer.unsign(signed, max_age=None) == b"data"


class TestSecurityAttacks:
def test_tampered_signature(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")

tampered_byte = bytes([signed[-23] ^ 0xFF])
tampered = signed[:-23] + tampered_byte + signed[-22:]
assert signer.unsign(tampered) is None

tampered_byte = bytes([signed[-12] ^ 0xFF])
tampered = signed[:-12] + tampered_byte + signed[-11:]
assert signer.unsign(tampered) is None

def test_tampered_payload(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")
tampered = b"XXXX" + signed[4:]
assert signer.unsign(tampered) is None

def test_signature_from_different_payload(self) -> None:
signer = TimestampSigner("secret")
signed1 = signer.sign(b"data1")
signed2 = signer.sign(b"data2")

combined1 = signed1[:-23]
suffix2 = signed2[-23:]
mixed = combined1 + suffix2
assert signer.unsign(mixed) is None


class TestMalformedData:
def test_minimum_length_requirement(self) -> None:
signer = TimestampSigner("secret")
assert signer.unsign(b"") is None
assert signer.unsign(b"a") is None
assert signer.unsign(b"a" * 20) is None
assert signer.unsign(b"a" * 29) is None
assert signer.unsign(b"short_") is None
assert signer.unsign(b"a" * 29 + b"_") is None

def test_missing_version_marker(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")

without_marker = signed[:-1]
assert signer.unsign(without_marker) is None

wrong_marker = signed[:-1] + b"X"
assert signer.unsign(wrong_marker) is None

def test_invalid_base64_in_combined_data(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")
tampered = b"!!!!!!AAAAA" + signed[-23:]
assert signer.unsign(tampered) is None

def test_invalid_base64_in_signature(self) -> None:
signer = TimestampSigner("secret")
signed = signer.sign(b"data")
tampered = signed[:-23] + b"!" * 22 + b"_"
assert signer.unsign(tampered) is None


class TestBase64UrlHelpers:
def test_encode_basic(self) -> None:
assert _b64_encode(b"hello") == b"aGVsbG8"
assert _b64_encode(b"") == b""

def test_decode_invalid_returns_none(self) -> None:
assert _b64_decode(b"A") is None
assert _b64_decode(b"AAAAA") is None

def test_round_trip(self) -> None:
test_cases = [b"", b"a", b"hello world", b"\x00\x01\x02\xff\xfe\xfd"]
for data in test_cases:
encoded = _b64_encode(data)
decoded = _b64_decode(encoded)
assert decoded == data
11 changes: 0 additions & 11 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.