diff --git a/README.md b/README.md index ae9996848..2c27b5918 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,6 @@ Starlette only requires `anyio`, and the following are optional: * [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. -* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. You can install all of these with `pip install starlette[full]`. @@ -135,7 +134,6 @@ in isolation. [httpx]: https://www.python-httpx.org/ [jinja2]: https://jinja.palletsprojects.com/ [python-multipart]: https://multipart.fastapiexpert.com/ -[itsdangerous]: https://itsdangerous.palletsprojects.com/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation [techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf diff --git a/docs/index.md b/docs/index.md index eeb11c0ea..a7415d243 100644 --- a/docs/index.md +++ b/docs/index.md @@ -109,7 +109,6 @@ Starlette only requires `anyio`, and the following dependencies are optional: * [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. -* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. You can install all of these with `pip install starlette[full]`. @@ -156,7 +155,6 @@ in isolation. [httpx]: https://www.python-httpx.org/ [jinja2]: https://jinja.palletsprojects.com/ [python-multipart]: https://multipart.fastapiexpert.com/ -[itsdangerous]: https://itsdangerous.palletsprojects.com/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation [techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf diff --git a/pyproject.toml b/pyproject.toml index e7799cdcd..dcf11511c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ [project.optional-dependencies] full = [ - "itsdangerous", "jinja2", "python-multipart>=0.0.18", "pyyaml", diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 1b95db4b0..918ca354f 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -1,14 +1,12 @@ from __future__ import annotations import json -from base64 import b64decode, b64encode +import time from typing import Literal -import itsdangerous -from itsdangerous.exc import BadSignature - from starlette.datastructures import MutableHeaders, Secret from starlette.requests import HTTPConnection +from starlette.signing import TimestampSigner from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -25,7 +23,7 @@ def __init__( domain: str | None = None, ) -> None: self.app = app - self.signer = itsdangerous.TimestampSigner(str(secret_key)) + self.signer = TimestampSigner(secret_key) self.session_cookie = session_cookie self.max_age = max_age self.path = path @@ -41,35 +39,44 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return connection = HTTPConnection(scope) - initial_session_was_empty = True + initial_payload: bytes | None = None + initial_timestamp: int | None = None if self.session_cookie in connection.cookies: data = connection.cookies[self.session_cookie].encode("utf-8") - try: - data = self.signer.unsign(data, max_age=self.max_age) - scope["session"] = json.loads(b64decode(data)) - initial_session_was_empty = False - except BadSignature: - scope["session"] = {} - else: + result = self.signer.unsign(data, max_age=self.max_age) + if result is not None: + initial_payload, initial_timestamp = result + scope["session"] = json.loads(initial_payload) + + if initial_payload is None: scope["session"] = {} async def send_wrapper(message: Message) -> None: if message["type"] == "http.response.start": if scope["session"]: # We have session data to persist. - data = b64encode(json.dumps(scope["session"]).encode("utf-8")) - data = self.signer.sign(data) - headers = MutableHeaders(scope=message) - header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( - session_cookie=self.session_cookie, - data=data.decode("utf-8"), - path=self.path, - max_age=f"Max-Age={self.max_age}; " if self.max_age else "", - security_flags=self.security_flags, - ) - headers.append("Set-Cookie", header_value) - elif not initial_session_was_empty: + current_payload = json.dumps(scope["session"]).encode("utf-8") + needs_cookie = False + + if current_payload != initial_payload: + needs_cookie = True + elif self.max_age is not None and initial_timestamp is not None: + if time.time() - initial_timestamp > self.max_age / 4: + needs_cookie = True + + if needs_cookie: + data = self.signer.sign(current_payload) + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( + session_cookie=self.session_cookie, + data=data.decode("utf-8"), + path=self.path, + max_age=f"Max-Age={self.max_age}; " if self.max_age is not None else "", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + elif initial_payload is not None: # The session has been cleared. headers = MutableHeaders(scope=message) header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( diff --git a/starlette/signing.py b/starlette/signing.py new file mode 100644 index 000000000..d7e304cb2 --- /dev/null +++ b/starlette/signing.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import time + +from starlette.datastructures import Secret + + +class TimestampSigner: + """ + Signs and unsigns data with HMAC-SHA256 and timestamp validation. + + Format: _ + """ + + def __init__(self, secret: str | Secret) -> None: + """ + Initialize signer with a secret key. + + Args: + secret: Secret key for HMAC signing + """ + self._secret = str(secret).encode("utf-8") + + def sign(self, data: bytes) -> bytes: + """ + Sign data with current timestamp. + + Args: + data: Raw data to sign + + Returns: + Signed token as bytes + """ + timestamp_bytes = int(time.time()).to_bytes(5, "big") + + combined = data + timestamp_bytes + combined_encoded = _b64_encode(combined) + + signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16] + signature_encoded = _b64_encode(signature) + + return combined_encoded + (signature_encoded + b"_") + + def unsign(self, signed_data: bytes, max_age: int | None = None) -> tuple[bytes, int] | None: + """ + Verify and extract data from signed token. + + Args: + signed_data: Signed token + max_age: Maximum age in seconds (optional) + + Returns: + Tuple of (payload bytes, timestamp) on success, None on failure + """ + # Quick pre-checks + if len(signed_data) < 30 or signed_data[-1] != 95: + return None + + signature_encoded = signed_data[-23:-1] + signature = _b64_decode(signature_encoded) + if signature is None or len(signature) != 16: + return None + + combined_encoded = signed_data[:-23] + expected_signature = hmac.HMAC(self._secret, combined_encoded, hashlib.sha256).digest()[:16] + if not hmac.compare_digest(signature, expected_signature): + return None + + combined = _b64_decode(combined_encoded) + if combined is None: # pragma: no cover + return None + + timestamp_bytes = combined[-5:] + timestamp = int.from_bytes(timestamp_bytes, "big") + + # Check timestamp age if max_age is set + if max_age is not None: + if time.time() - timestamp > max_age: + return None + + data = combined[:-5] + return (data, timestamp) + + +def _b64_encode(data: bytes) -> bytes: + """Encode bytes to base64url format without padding.""" + return base64.urlsafe_b64encode(data).rstrip(b"=") + + +def _b64_decode(data: bytes) -> bytes | None: + """Decode base64url format, adding padding if needed. Returns None on error.""" + # Add padding if needed + padding = 4 - (len(data) % 4) + if padding != 4: + data = data + b"=" * padding + + try: + return base64.urlsafe_b64decode(data) + except Exception: + return None diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index b4f3c64fa..4e2ff3942 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -1,4 +1,5 @@ import re +from unittest import mock from starlette.applications import Starlette from starlette.middleware import Middleware @@ -198,3 +199,87 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None: client.cookies.delete("session") response = client.get("/view_session") assert response.json() == {"session": {}} + + +def test_session_no_cookie_when_unchanged(test_client_factory: TestClientFactory) -> None: + app = Starlette( + routes=[ + Route("/view_session", endpoint=view_session), + Route("/update_session", endpoint=update_session, methods=["POST"]), + ], + middleware=[Middleware(SessionMiddleware, secret_key="example")], + ) + client = test_client_factory(app) + + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" in response.headers + + response = client.get("/view_session") + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" not in response.headers + + +def test_session_cookie_when_modified(test_client_factory: TestClientFactory) -> None: + app = Starlette( + routes=[ + Route("/view_session", endpoint=view_session), + Route("/update_session", endpoint=update_session, methods=["POST"]), + ], + middleware=[Middleware(SessionMiddleware, secret_key="example")], + ) + client = test_client_factory(app) + + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" in response.headers + + response = client.post("/update_session", json={"some": "data", "more": "values"}) + assert response.json() == {"session": {"some": "data", "more": "values"}} + assert "set-cookie" in response.headers + + +def test_session_cookie_refresh_when_stale(test_client_factory: TestClientFactory) -> None: + max_age = 60 + app = Starlette( + routes=[ + Route("/view_session", endpoint=view_session), + Route("/update_session", endpoint=update_session, methods=["POST"]), + ], + middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=max_age)], + ) + client = test_client_factory(app) + + with mock.patch("time.time", return_value=1000): + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" in response.headers + + with mock.patch("time.time", return_value=1000): + response = client.get("/view_session") + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" not in response.headers + + with mock.patch("time.time", return_value=1000 + max_age / 4 + 1): + response = client.get("/view_session") + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" in response.headers + + +def test_session_no_cookie_when_max_age_none(test_client_factory: TestClientFactory) -> None: + app = Starlette( + routes=[ + Route("/view_session", endpoint=view_session), + Route("/update_session", endpoint=update_session, methods=["POST"]), + ], + middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)], + ) + client = test_client_factory(app) + + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" in response.headers + + response = client.get("/view_session") + assert response.json() == {"session": {"some": "data"}} + assert "set-cookie" not in response.headers diff --git a/tests/test_signing.py b/tests/test_signing.py new file mode 100644 index 000000000..120487d44 --- /dev/null +++ b/tests/test_signing.py @@ -0,0 +1,165 @@ +from unittest import mock + +from starlette.signing import TimestampSigner, _b64_decode, _b64_encode + + +class TestTimestampSigner: + def test_sign_basic(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"hello") + assert isinstance(signed, bytes) + assert len(signed) > 30 + + def test_round_trip(self) -> None: + signer = TimestampSigner("secret") + test_cases = [ + b"", + b"a", + b"hello world", + b"\x00\x01\x02\xff\xfe\xfd", + b"\x00\x00\x00\x00", + b"x" * 1000, + b'{"user": "alice", "id": 123}', + "Hello δΈ–η•Œ 🌍".encode(), + ] + for data in test_cases: + signed = signer.sign(data) + result = signer.unsign(signed) + assert result is not None + payload, timestamp = result + assert payload == data + assert isinstance(timestamp, int) + + def test_output_is_ascii(self) -> None: + signer = TimestampSigner("secret") + data = "Hello δΈ–η•Œ 🌍".encode() + signed = signer.sign(data) + signed.decode("ascii") + assert all(b < 128 for b in signed) + + def test_no_padding_in_output(self) -> None: + signer = TimestampSigner("secret") + for size in range(10): + data = b"x" * size + signed = signer.sign(data) + assert b"=" not in signed + + def test_no_forbidden_characters(self) -> None: + signer = TimestampSigner("secret") + forbidden = set(b' ,;"\\') + test_data = [ + b"simple", + b'{"user": "alice", "roles": ["admin"]}', + b"\x00\xff" * 10, + b"x" * 10, + ] + for data in test_data: + signed = signer.sign(data) + assert not any(c in forbidden for c in signed) + valid_chars = set(b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in signed) + + +class TestTimestampValidation: + def test_unsign_with_valid_max_age(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + result = signer.unsign(signed, max_age=10) + assert result is not None + assert result[0] == b"data" + result = signer.unsign(signed, max_age=1000) + assert result is not None + assert result[0] == b"data" + + def test_unsign_with_expired_max_age(self) -> None: + signer = TimestampSigner("secret") + with mock.patch("time.time", return_value=1000): + signed = signer.sign(b"data") + + with mock.patch("time.time", return_value=1003): + assert signer.unsign(signed, max_age=1) is None + assert signer.unsign(signed, max_age=0) is None + + result = signer.unsign(signed, max_age=None) + assert result is not None + assert result[0] == b"data" + + +class TestSecurityAttacks: + def test_tampered_signature(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + + tampered_byte = bytes([signed[-23] ^ 0xFF]) + tampered = signed[:-23] + tampered_byte + signed[-22:] + assert signer.unsign(tampered) is None + + tampered_byte = bytes([signed[-12] ^ 0xFF]) + tampered = signed[:-12] + tampered_byte + signed[-11:] + assert signer.unsign(tampered) is None + + def test_tampered_payload(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + tampered = b"XXXX" + signed[4:] + assert signer.unsign(tampered) is None + + def test_signature_from_different_payload(self) -> None: + signer = TimestampSigner("secret") + signed1 = signer.sign(b"data1") + signed2 = signer.sign(b"data2") + + combined1 = signed1[:-23] + suffix2 = signed2[-23:] + mixed = combined1 + suffix2 + assert signer.unsign(mixed) is None + + +class TestMalformedData: + def test_minimum_length_requirement(self) -> None: + signer = TimestampSigner("secret") + assert signer.unsign(b"") is None + assert signer.unsign(b"a") is None + assert signer.unsign(b"a" * 20) is None + assert signer.unsign(b"a" * 29) is None + assert signer.unsign(b"short_") is None + assert signer.unsign(b"a" * 29 + b"_") is None + + def test_missing_version_marker(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + + without_marker = signed[:-1] + assert signer.unsign(without_marker) is None + + wrong_marker = signed[:-1] + b"X" + assert signer.unsign(wrong_marker) is None + + def test_invalid_base64_in_combined_data(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + tampered = b"!!!!!!AAAAA" + signed[-23:] + assert signer.unsign(tampered) is None + + def test_invalid_base64_in_signature(self) -> None: + signer = TimestampSigner("secret") + signed = signer.sign(b"data") + tampered = signed[:-23] + b"!" * 22 + b"_" + assert signer.unsign(tampered) is None + + +class TestBase64UrlHelpers: + def test_encode_basic(self) -> None: + assert _b64_encode(b"hello") == b"aGVsbG8" + assert _b64_encode(b"") == b"" + + def test_decode_invalid_returns_none(self) -> None: + assert _b64_decode(b"A") is None + assert _b64_decode(b"AAAAA") is None + + def test_round_trip(self) -> None: + test_cases = [b"", b"a", b"hello world", b"\x00\x01\x02\xff\xfe\xfd"] + for data in test_cases: + encoded = _b64_encode(data) + decoded = _b64_decode(encoded) + assert decoded == data diff --git a/uv.lock b/uv.lock index 13a0ebd0f..3634687c2 100644 --- a/uv.lock +++ b/uv.lock @@ -460,15 +460,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] -[[package]] -name = "itsdangerous" -version = "2.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, -] - [[package]] name = "jaraco-classes" version = "3.4.0" @@ -1177,7 +1168,6 @@ dependencies = [ [package.optional-dependencies] full = [ { name = "httpx" }, - { name = "itsdangerous" }, { name = "jinja2" }, { name = "python-multipart" }, { name = "pyyaml" }, @@ -1206,7 +1196,6 @@ docs = [ requires-dist = [ { name = "anyio", specifier = ">=3.6.2,<5" }, { name = "httpx", marker = "extra == 'full'", specifier = ">=0.27.0,<0.29.0" }, - { name = "itsdangerous", marker = "extra == 'full'" }, { name = "jinja2", marker = "extra == 'full'" }, { name = "python-multipart", marker = "extra == 'full'", specifier = ">=0.0.18" }, { name = "pyyaml", marker = "extra == 'full'" },