From aab6311a5743dc3624d9888c6c8064277ac8931c Mon Sep 17 00:00:00 2001 From: Dave Gaeddert Date: Wed, 16 Oct 2024 22:27:18 -0500 Subject: [PATCH] Rework middleware and add BUILTIN_MIDDLEWARE --- plain-auth/README.md | 6 +- plain-auth/plain/auth/README.md | 6 +- plain-importmap/test_project/settings.py | 2 - .../tests/provider_tests/test_github.py | 2 +- plain-oauth/tests/test_providers.py | 8 +-- plain-sessions/tests/test_sessions.py | 4 +- plain-staff/README.md | 2 - plain-staff/plain/staff/README.md | 2 - plain-staff/plain/staff/querystats/README.md | 2 - plain/plain/csrf/middleware.py | 59 ++++++----------- plain/plain/internal/handlers/base.py | 14 +++- .../{ => internal}/middleware/__init__.py | 0 plain/plain/internal/middleware/headers.py | 19 ++++++ plain/plain/internal/middleware/https.py | 36 +++++++++++ .../middleware/slash.py} | 60 +---------------- plain/plain/middleware/README.md | 3 - plain/plain/middleware/gzip.py | 64 ------------------- plain/plain/preflight/security/csrf.py | 6 +- plain/plain/runtime/README.md | 2 - plain/plain/runtime/global_settings.py | 8 +-- 20 files changed, 106 insertions(+), 199 deletions(-) rename plain/plain/{ => internal}/middleware/__init__.py (100%) create mode 100644 plain/plain/internal/middleware/headers.py create mode 100644 plain/plain/internal/middleware/https.py rename plain/plain/{middleware/common.py => internal/middleware/slash.py} (51%) delete mode 100644 plain/plain/middleware/README.md delete mode 100644 plain/plain/middleware/gzip.py diff --git a/plain-auth/README.md b/plain-auth/README.md index 7551d8d702..15a032a395 100644 --- a/plain-auth/README.md +++ b/plain-auth/README.md @@ -23,10 +23,8 @@ INSTALLED_PACKAGES = [ ] MIDDLEWARE = [ - "plain.sessions.middleware.SessionMiddleware", # <-- - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", - "plain.auth.middleware.AuthenticationMiddleware", # <-- + "plain.sessions.middleware.SessionMiddleware", + 
"plain.auth.middleware.AuthenticationMiddleware", ] AUTH_USER_MODEL = "users.User" diff --git a/plain-auth/plain/auth/README.md b/plain-auth/plain/auth/README.md index 2839b587f8..e21e99e13f 100644 --- a/plain-auth/plain/auth/README.md +++ b/plain-auth/plain/auth/README.md @@ -21,10 +21,8 @@ INSTALLED_PACKAGES = [ ] MIDDLEWARE = [ - "plain.sessions.middleware.SessionMiddleware", # <-- - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", - "plain.auth.middleware.AuthenticationMiddleware", # <-- + "plain.sessions.middleware.SessionMiddleware", + "plain.auth.middleware.AuthenticationMiddleware", ] AUTH_USER_MODEL = "users.User" diff --git a/plain-importmap/test_project/settings.py b/plain-importmap/test_project/settings.py index 51b7d8133c..49f1bf51e3 100644 --- a/plain-importmap/test_project/settings.py +++ b/plain-importmap/test_project/settings.py @@ -26,8 +26,6 @@ MIDDLEWARE = [ "plain.sessions.middleware.SessionMiddleware", - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", "plain.auth.middleware.AuthenticationMiddleware", ] diff --git a/plain-oauth/tests/provider_tests/test_github.py b/plain-oauth/tests/provider_tests/test_github.py index 5ab93ca5b1..fa1840a8ef 100644 --- a/plain-oauth/tests/provider_tests/test_github.py +++ b/plain-oauth/tests/provider_tests/test_github.py @@ -39,7 +39,7 @@ def test_github_provider(db, client, settings): assert response.status_code == 302 assert ( response.url - == "https://github.com/login/oauth/authorize?client_id=test_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fgithub%2Fcallback%2F&response_type=code&scope=user&state=dummy_state" + == "https://github.com/login/oauth/authorize?client_id=test_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fgithub%2Fcallback%2F&response_type=code&scope=user&state=dummy_state" ) # GitHub redirects to the callback url diff --git a/plain-oauth/tests/test_providers.py b/plain-oauth/tests/test_providers.py index 
2d0029b7a3..2762237dd9 100644 --- a/plain-oauth/tests/test_providers.py +++ b/plain-oauth/tests/test_providers.py @@ -68,7 +68,7 @@ def test_dummy_signup(db, client, settings): assert response.status_code == 302 assert ( response.url - == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" + == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" ) # Provider redirects to the callback url @@ -148,7 +148,7 @@ def test_dummy_login_connection(db, client, settings): assert response.status_code == 302 assert ( response.url - == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" + == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" ) # Provider redirects to the callback url @@ -215,7 +215,7 @@ def test_dummy_login_without_connection(db, client, settings): assert response.status_code == 302 assert ( response.url - == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" + == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" ) # Provider redirects to the callback url @@ -253,7 +253,7 @@ def test_dummy_connect(db, client, settings): assert response.status_code == 302 assert ( response.url - == 
"https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" + == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state" ) # Provider redirects to the callback url diff --git a/plain-sessions/tests/test_sessions.py b/plain-sessions/tests/test_sessions.py index 2beb54af16..e9dfff60f2 100644 --- a/plain-sessions/tests/test_sessions.py +++ b/plain-sessions/tests/test_sessions.py @@ -4,6 +4,8 @@ def test_session_created(db, client): assert Session.objects.count() == 0 - client.get("/") + response = client.get("/") + + assert response.status_code == 200 assert Session.objects.count() == 1 diff --git a/plain-staff/README.md b/plain-staff/README.md index 8c1b93f85e..515263dd8b 100644 --- a/plain-staff/README.md +++ b/plain-staff/README.md @@ -52,8 +52,6 @@ INSTALLED_PACKAGES = [ MIDDLEWARE = [ "plain.sessions.middleware.SessionMiddleware", - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", "plain.auth.middleware.AuthenticationMiddleware", "plain.staff.querystats.QueryStatsMiddleware", diff --git a/plain-staff/plain/staff/README.md b/plain-staff/plain/staff/README.md index 0d54e31327..9177c17539 100644 --- a/plain-staff/plain/staff/README.md +++ b/plain-staff/plain/staff/README.md @@ -50,8 +50,6 @@ INSTALLED_PACKAGES = [ MIDDLEWARE = [ "plain.sessions.middleware.SessionMiddleware", - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", "plain.auth.middleware.AuthenticationMiddleware", "plain.staff.querystats.QueryStatsMiddleware", diff --git a/plain-staff/plain/staff/querystats/README.md b/plain-staff/plain/staff/querystats/README.md index 60fde5b261..48cea56a1b 100644 --- a/plain-staff/plain/staff/querystats/README.md +++ 
b/plain-staff/plain/staff/querystats/README.md @@ -27,8 +27,6 @@ INSTALLED_PACKAGES = [ MIDDLEWARE = [ "plain.sessions.middleware.SessionMiddleware", - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", "plain.auth.middleware.AuthenticationMiddleware", "plain.staff.querystats.QueryStatsMiddleware", diff --git a/plain/plain/csrf/middleware.py b/plain/plain/csrf/middleware.py index e0cfbf2976..7a4bde7f8f 100644 --- a/plain/plain/csrf/middleware.py +++ b/plain/plain/csrf/middleware.py @@ -9,7 +9,7 @@ from collections import defaultdict from urllib.parse import urlparse -from plain.exceptions import DisallowedHost, ImproperlyConfigured +from plain.exceptions import DisallowedHost from plain.http import HttpHeaders, UnreadablePostError from plain.logs import log_response from plain.runtime import settings @@ -242,44 +242,31 @@ def _get_secret(self, request): If the CSRF_USE_SESSIONS setting is false, raises InvalidTokenFormat if the request's secret has invalid characters or an invalid length. """ - if settings.CSRF_USE_SESSIONS: - try: - csrf_secret = request.session.get(CSRF_SESSION_KEY) - except AttributeError: - raise ImproperlyConfigured( - "CSRF_USE_SESSIONS is enabled, but request.session is not " - "set. SessionMiddleware must appear before CsrfViewMiddleware " - "in MIDDLEWARE." - ) + try: + csrf_secret = request.COOKIES[settings.CSRF_COOKIE_NAME] + except KeyError: + csrf_secret = None else: - try: - csrf_secret = request.COOKIES[settings.CSRF_COOKIE_NAME] - except KeyError: - csrf_secret = None - else: - # This can raise InvalidTokenFormat. - _check_token_format(csrf_secret) + # This can raise InvalidTokenFormat. 
+ _check_token_format(csrf_secret) + if csrf_secret is None: return None return csrf_secret def _set_csrf_cookie(self, request, response): - if settings.CSRF_USE_SESSIONS: - if request.session.get(CSRF_SESSION_KEY) != request.META["CSRF_COOKIE"]: - request.session[CSRF_SESSION_KEY] = request.META["CSRF_COOKIE"] - else: - response.set_cookie( - settings.CSRF_COOKIE_NAME, - request.META["CSRF_COOKIE"], - max_age=settings.CSRF_COOKIE_AGE, - domain=settings.CSRF_COOKIE_DOMAIN, - path=settings.CSRF_COOKIE_PATH, - secure=settings.CSRF_COOKIE_SECURE, - httponly=settings.CSRF_COOKIE_HTTPONLY, - samesite=settings.CSRF_COOKIE_SAMESITE, - ) - # Set the Vary header since content varies with the CSRF cookie. - patch_vary_headers(response, ("Cookie",)) + response.set_cookie( + settings.CSRF_COOKIE_NAME, + request.META["CSRF_COOKIE"], + max_age=settings.CSRF_COOKIE_AGE, + domain=settings.CSRF_COOKIE_DOMAIN, + path=settings.CSRF_COOKIE_PATH, + secure=settings.CSRF_COOKIE_SECURE, + httponly=settings.CSRF_COOKIE_HTTPONLY, + samesite=settings.CSRF_COOKIE_SAMESITE, + ) + # Set the Vary header since content varies with the CSRF cookie. + patch_vary_headers(response, ("Cookie",)) def _origin_verified(self, request): request_origin = request.META["HTTP_ORIGIN"] @@ -331,11 +318,7 @@ def _check_referer(self, request): ): return # Allow matching the configured cookie domain. - good_referer = ( - settings.SESSION_COOKIE_DOMAIN - if settings.CSRF_USE_SESSIONS - else settings.CSRF_COOKIE_DOMAIN - ) + good_referer = settings.CSRF_COOKIE_DOMAIN if good_referer is None: # If no cookie domain is configured, allow matching the current # host:port exactly if it's permitted by ALLOWED_HOSTS. 
diff --git a/plain/plain/internal/handlers/base.py b/plain/plain/internal/handlers/base.py index 9b721e63a8..727547f547 100644 --- a/plain/plain/internal/handlers/base.py +++ b/plain/plain/internal/handlers/base.py @@ -13,6 +13,15 @@ logger = logging.getLogger("plain.request") +# These middleware classes are always used by Plain. +BUILTIN_MIDDLEWARE = [ + "plain.internal.middleware.headers.DefaultHeadersMiddleware", + "plain.internal.middleware.https.HttpsRedirectMiddleware", + "plain.internal.middleware.slash.RedirectSlashMiddleware", + "plain.csrf.middleware.CsrfViewMiddleware", +] + + class BaseHandler: _view_middleware = None _middleware_chain = None @@ -27,7 +36,10 @@ def load_middleware(self): get_response = self._get_response handler = convert_exception_to_response(get_response) - for middleware_path in reversed(settings.MIDDLEWARE): + + middlewares = reversed(BUILTIN_MIDDLEWARE + settings.MIDDLEWARE) + + for middleware_path in middlewares: middleware = import_string(middleware_path) mw_instance = middleware(handler) diff --git a/plain/plain/middleware/__init__.py b/plain/plain/internal/middleware/__init__.py similarity index 100% rename from plain/plain/middleware/__init__.py rename to plain/plain/internal/middleware/__init__.py diff --git a/plain/plain/internal/middleware/headers.py b/plain/plain/internal/middleware/headers.py new file mode 100644 index 0000000000..1da30eb254 --- /dev/null +++ b/plain/plain/internal/middleware/headers.py @@ -0,0 +1,19 @@ +from plain.runtime import settings + + +class DefaultHeadersMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + response = self.get_response(request) + + for header, value in settings.DEFAULT_RESPONSE_HEADERS.items(): + response.headers.setdefault(header, value) + + # Add the Content-Length header to non-streaming responses if not + # already set. 
+ if not response.streaming and not response.has_header("Content-Length"): + response.headers["Content-Length"] = str(len(response.content)) + + return response diff --git a/plain/plain/internal/middleware/https.py b/plain/plain/internal/middleware/https.py new file mode 100644 index 0000000000..475479845e --- /dev/null +++ b/plain/plain/internal/middleware/https.py @@ -0,0 +1,36 @@ +import re + +from plain.http import ResponsePermanentRedirect +from plain.runtime import settings + + +class HttpsRedirectMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + # Settings for https (compile regexes once) + self.https_redirect_enabled = settings.HTTPS_REDIRECT_ENABLED + self.https_redirect_host = settings.HTTPS_REDIRECT_HOST + self.https_redirect_exempt = [ + re.compile(r) for r in settings.HTTPS_REDIRECT_EXEMPT + ] + + def __call__(self, request): + """ + Redirect to HTTPS based on settings.HTTPS_REDIRECT_ENABLED + """ + + if redirect_response := self.maybe_https_redirect(request): + return redirect_response + + return self.get_response(request) + + def maybe_https_redirect(self, request): + path = request.path.lstrip("/") + if ( + self.https_redirect_enabled + and not request.is_https() + and not any(pattern.search(path) for pattern in self.https_redirect_exempt) + ): + host = self.https_redirect_host or request.get_host() + return ResponsePermanentRedirect(f"https://{host}{request.get_full_path()}") diff --git a/plain/plain/middleware/common.py b/plain/plain/internal/middleware/slash.py similarity index 51% rename from plain/plain/middleware/common.py rename to plain/plain/internal/middleware/slash.py index 32818686d3..bb83b73e3d 100644 --- a/plain/plain/middleware/common.py +++ b/plain/plain/internal/middleware/slash.py @@ -1,57 +1,20 @@ -import re - from plain.http import ResponsePermanentRedirect from plain.runtime import settings from plain.urls import is_valid_path from plain.utils.http import escape_leading_slashes -class CommonMiddleware: - 
""" - "Common" middleware for taking care of some basic operations: - - - Redirecting to HTTPS: Based on the HTTPS_REDIRECT_ENABLED setting, - redirect to HTTPS if the request is not secure. - - - Default response headers: Add default headers to responses. - - - URL rewriting: Based on the APPEND_SLASH setting, - append missing slashes. - - - If APPEND_SLASH is set and the initial URL doesn't end with a - slash, and it is not found in urlpatterns, form a new URL by - appending a slash at the end. If this new URL is found in - urlpatterns, return an HTTP redirect to this new URL; otherwise - process the initial URL as usual. - - This behavior can be customized by subclassing CommonMiddleware and - overriding the response_redirect_class attribute. - """ - - response_redirect_class = ResponsePermanentRedirect - +class RedirectSlashMiddleware: def __init__(self, get_response): self.get_response = get_response - # Settings for https (compile regexes once) - self.https_redirect_enabled = settings.HTTPS_REDIRECT_ENABLED - self.https_redirect_host = settings.HTTPS_REDIRECT_HOST - self.https_redirect_exempt = [ - re.compile(r) for r in settings.HTTPS_REDIRECT_EXEMPT - ] - def __call__(self, request): """ Rewrite the URL based on settings.APPEND_SLASH """ - if redirect_response := self.maybe_https_redirect(request): - return redirect_response - response = self.get_response(request) - self.set_default_headers(response) - """ When the status code of the response is 404, it may redirect to a path with an appended slash if should_redirect_with_slash() returns True. @@ -59,29 +22,10 @@ def __call__(self, request): # If the given URL is "Not Found", then check if we should redirect to # a path with a slash appended. if response.status_code == 404 and self.should_redirect_with_slash(request): - return self.response_redirect_class(self.get_full_path_with_slash(request)) - - # Add the Content-Length header to non-streaming responses if not - # already set. 
- if not response.streaming and not response.has_header("Content-Length"): - response.headers["Content-Length"] = str(len(response.content)) + return ResponsePermanentRedirect(self.get_full_path_with_slash(request)) return response - def maybe_https_redirect(self, request): - path = request.path.lstrip("/") - if ( - self.https_redirect_enabled - and not request.is_https() - and not any(pattern.search(path) for pattern in self.https_redirect_exempt) - ): - host = self.https_redirect_host or request.get_host() - return ResponsePermanentRedirect(f"https://{host}{request.get_full_path()}") - - def set_default_headers(self, response): - for header, value in settings.DEFAULT_RESPONSE_HEADERS.items(): - response.headers.setdefault(header, value) - def should_redirect_with_slash(self, request): """ Return True if settings.APPEND_SLASH is True and appending a slash to diff --git a/plain/plain/middleware/README.md b/plain/plain/middleware/README.md deleted file mode 100644 index ad97f57f76..0000000000 --- a/plain/plain/middleware/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Middleware - -Hook into the request/response cycle. diff --git a/plain/plain/middleware/gzip.py b/plain/plain/middleware/gzip.py deleted file mode 100644 index 17b71fcaa4..0000000000 --- a/plain/plain/middleware/gzip.py +++ /dev/null @@ -1,64 +0,0 @@ -from plain.utils.cache import patch_vary_headers -from plain.utils.regex_helper import _lazy_re_compile -from plain.utils.text import compress_sequence, compress_string - -re_accepts_gzip = _lazy_re_compile(r"\bgzip\b") - - -class GZipMiddleware: - """ - Compress content if the browser allows gzip compression. - Set the Vary header accordingly, so that caches will base their storage - on the Accept-Encoding header. - """ - - max_random_bytes = 100 - - def __init__(self, get_response): - self.get_response = get_response - - def __call__(self, request): - response = self.get_response(request) - - # It's not worth attempting to compress really short responses. 
- if not response.streaming and len(response.content) < 200: - return response - - # Avoid gzipping if we've already got a content-encoding. - if response.has_header("Content-Encoding"): - return response - - patch_vary_headers(response, ("Accept-Encoding",)) - - ae = request.META.get("HTTP_ACCEPT_ENCODING", "") - if not re_accepts_gzip.search(ae): - return response - - if response.streaming: - response.streaming_content = compress_sequence( - response.streaming_content, - max_random_bytes=self.max_random_bytes, - ) - # Delete the `Content-Length` header for streaming content, because - # we won't know the compressed size until we stream it. - del response.headers["Content-Length"] - else: - # Return the compressed content only if it's actually shorter. - compressed_content = compress_string( - response.content, - max_random_bytes=self.max_random_bytes, - ) - if len(compressed_content) >= len(response.content): - return response - response.content = compressed_content - response.headers["Content-Length"] = str(len(response.content)) - - # If there is a strong ETag, make it weak to fulfill the requirements - # of RFC 9110 Section 8.8.1 while also allowing conditional request - # matches on ETags. 
- etag = response.get("ETag") - if etag and etag.startswith('"'): - response.headers["ETag"] = "W/" + etag - response.headers["Content-Encoding"] = "gzip" - - return response diff --git a/plain/plain/preflight/security/csrf.py b/plain/plain/preflight/security/csrf.py index 36b9951452..53b1f474ee 100644 --- a/plain/plain/preflight/security/csrf.py +++ b/plain/plain/preflight/security/csrf.py @@ -32,9 +32,5 @@ def check_csrf_middleware(package_configs, **kwargs): @register(deploy=True) def check_csrf_cookie_secure(package_configs, **kwargs): - passed_check = ( - settings.CSRF_USE_SESSIONS - or not _csrf_middleware() - or settings.CSRF_COOKIE_SECURE is True - ) + passed_check = not _csrf_middleware() or settings.CSRF_COOKIE_SECURE is True return [] if passed_check else [W016] diff --git a/plain/plain/runtime/README.md b/plain/plain/runtime/README.md index 2c7adc7182..4dfc4f6048 100644 --- a/plain/plain/runtime/README.md +++ b/plain/plain/runtime/README.md @@ -55,8 +55,6 @@ DEBUG = environ.get("DEBUG", "false").lower() in ("true", "1", "yes") MIDDLEWARE = [ "plain.sessions.middleware.SessionMiddleware", - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", "plain.auth.middleware.AuthenticationMiddleware", ] diff --git a/plain/plain/runtime/global_settings.py b/plain/plain/runtime/global_settings.py index 2fb3a89bf2..6547c4adec 100644 --- a/plain/plain/runtime/global_settings.py +++ b/plain/plain/runtime/global_settings.py @@ -29,7 +29,7 @@ DEFAULT_CHARSET = "utf-8" # List of strings representing installed packages. -INSTALLED_PACKAGES: list = [] +INSTALLED_PACKAGES: list[str] = [] # Whether to append trailing slashes to URLs. APPEND_SLASH = True @@ -110,10 +110,7 @@ # List of middleware to use. Order is important; in the request phase, these # middleware will be applied in the order given, and in the response # phase the middleware will be applied in reverse order. 
-MIDDLEWARE = [ - "plain.middleware.common.CommonMiddleware", - "plain.csrf.middleware.CsrfViewMiddleware", -] +MIDDLEWARE: list[str] = [] ########### # SIGNING # @@ -135,7 +132,6 @@ CSRF_COOKIE_SAMESITE = "Lax" CSRF_HEADER_NAME = "HTTP_X_CSRFTOKEN" CSRF_TRUSTED_ORIGINS: list[str] = [] -CSRF_USE_SESSIONS = False ########### # LOGGING #