Skip to content

Commit

Permalink
Rework middleware and add BUILTIN_MIDDLEWARE
Browse files Browse the repository at this point in the history
davegaeddert committed Oct 17, 2024
1 parent 18d4e87 commit aab6311
Showing 20 changed files with 106 additions and 199 deletions.
6 changes: 2 additions & 4 deletions plain-auth/README.md
Original file line number Diff line number Diff line change
@@ -23,10 +23,8 @@ INSTALLED_PACKAGES = [
]

MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware", # <--
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware", # <--
"plain.sessions.middleware.SessionMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",
]

AUTH_USER_MODEL = "users.User"
6 changes: 2 additions & 4 deletions plain-auth/plain/auth/README.md
Original file line number Diff line number Diff line change
@@ -21,10 +21,8 @@ INSTALLED_PACKAGES = [
]

MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware", # <--
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware", # <--
"plain.sessions.middleware.SessionMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",
]

AUTH_USER_MODEL = "users.User"
2 changes: 0 additions & 2 deletions plain-importmap/test_project/settings.py
Original file line number Diff line number Diff line change
@@ -26,8 +26,6 @@

MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware",
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",
]

2 changes: 1 addition & 1 deletion plain-oauth/tests/provider_tests/test_github.py
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ def test_github_provider(db, client, settings):
assert response.status_code == 302
assert (
response.url
== "https://github.com/login/oauth/authorize?client_id=test_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fgithub%2Fcallback%2F&response_type=code&scope=user&state=dummy_state"
== "https://github.com/login/oauth/authorize?client_id=test_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fgithub%2Fcallback%2F&response_type=code&scope=user&state=dummy_state"
)

# GitHub redirects to the callback url
8 changes: 4 additions & 4 deletions plain-oauth/tests/test_providers.py
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ def test_dummy_signup(db, client, settings):
assert response.status_code == 302
assert (
response.url
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
)

# Provider redirects to the callback url
@@ -148,7 +148,7 @@ def test_dummy_login_connection(db, client, settings):
assert response.status_code == 302
assert (
response.url
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
)

# Provider redirects to the callback url
@@ -215,7 +215,7 @@ def test_dummy_login_without_connection(db, client, settings):
assert response.status_code == 302
assert (
response.url
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
)

# Provider redirects to the callback url
@@ -253,7 +253,7 @@ def test_dummy_connect(db, client, settings):
assert response.status_code == 302
assert (
response.url
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=http%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
== "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
)

# Provider redirects to the callback url
4 changes: 3 additions & 1 deletion plain-sessions/tests/test_sessions.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@
def test_session_created(db, client):
assert Session.objects.count() == 0

client.get("/")
response = client.get("/")

assert response.status_code == 200

assert Session.objects.count() == 1
2 changes: 0 additions & 2 deletions plain-staff/README.md
Original file line number Diff line number Diff line change
@@ -52,8 +52,6 @@ INSTALLED_PACKAGES = [
MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware",
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",
"plain.staff.querystats.QueryStatsMiddleware",
2 changes: 0 additions & 2 deletions plain-staff/plain/staff/README.md
Original file line number Diff line number Diff line change
@@ -50,8 +50,6 @@ INSTALLED_PACKAGES = [
MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware",
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",
"plain.staff.querystats.QueryStatsMiddleware",
2 changes: 0 additions & 2 deletions plain-staff/plain/staff/querystats/README.md
Original file line number Diff line number Diff line change
@@ -27,8 +27,6 @@ INSTALLED_PACKAGES = [

MIDDLEWARE = [
"plain.sessions.middleware.SessionMiddleware",
"plain.middleware.common.CommonMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
"plain.auth.middleware.AuthenticationMiddleware",

"plain.staff.querystats.QueryStatsMiddleware",
59 changes: 21 additions & 38 deletions plain/plain/csrf/middleware.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from collections import defaultdict
from urllib.parse import urlparse

from plain.exceptions import DisallowedHost, ImproperlyConfigured
from plain.exceptions import DisallowedHost
from plain.http import HttpHeaders, UnreadablePostError
from plain.logs import log_response
from plain.runtime import settings
@@ -242,44 +242,31 @@ def _get_secret(self, request):
If the CSRF_USE_SESSIONS setting is false, raises InvalidTokenFormat if
the request's secret has invalid characters or an invalid length.
"""
if settings.CSRF_USE_SESSIONS:
try:
csrf_secret = request.session.get(CSRF_SESSION_KEY)
except AttributeError:
raise ImproperlyConfigured(
"CSRF_USE_SESSIONS is enabled, but request.session is not "
"set. SessionMiddleware must appear before CsrfViewMiddleware "
"in MIDDLEWARE."
)
try:
csrf_secret = request.COOKIES[settings.CSRF_COOKIE_NAME]
except KeyError:
csrf_secret = None
else:
try:
csrf_secret = request.COOKIES[settings.CSRF_COOKIE_NAME]
except KeyError:
csrf_secret = None
else:
# This can raise InvalidTokenFormat.
_check_token_format(csrf_secret)
# This can raise InvalidTokenFormat.
_check_token_format(csrf_secret)

if csrf_secret is None:
return None
return csrf_secret

def _set_csrf_cookie(self, request, response):
if settings.CSRF_USE_SESSIONS:
if request.session.get(CSRF_SESSION_KEY) != request.META["CSRF_COOKIE"]:
request.session[CSRF_SESSION_KEY] = request.META["CSRF_COOKIE"]
else:
response.set_cookie(
settings.CSRF_COOKIE_NAME,
request.META["CSRF_COOKIE"],
max_age=settings.CSRF_COOKIE_AGE,
domain=settings.CSRF_COOKIE_DOMAIN,
path=settings.CSRF_COOKIE_PATH,
secure=settings.CSRF_COOKIE_SECURE,
httponly=settings.CSRF_COOKIE_HTTPONLY,
samesite=settings.CSRF_COOKIE_SAMESITE,
)
# Set the Vary header since content varies with the CSRF cookie.
patch_vary_headers(response, ("Cookie",))
response.set_cookie(
settings.CSRF_COOKIE_NAME,
request.META["CSRF_COOKIE"],
max_age=settings.CSRF_COOKIE_AGE,
domain=settings.CSRF_COOKIE_DOMAIN,
path=settings.CSRF_COOKIE_PATH,
secure=settings.CSRF_COOKIE_SECURE,
httponly=settings.CSRF_COOKIE_HTTPONLY,
samesite=settings.CSRF_COOKIE_SAMESITE,
)
# Set the Vary header since content varies with the CSRF cookie.
patch_vary_headers(response, ("Cookie",))

def _origin_verified(self, request):
request_origin = request.META["HTTP_ORIGIN"]
@@ -331,11 +318,7 @@ def _check_referer(self, request):
):
return
# Allow matching the configured cookie domain.
good_referer = (
settings.SESSION_COOKIE_DOMAIN
if settings.CSRF_USE_SESSIONS
else settings.CSRF_COOKIE_DOMAIN
)
good_referer = settings.CSRF_COOKIE_DOMAIN
if good_referer is None:
# If no cookie domain is configured, allow matching the current
# host:port exactly if it's permitted by ALLOWED_HOSTS.
14 changes: 13 additions & 1 deletion plain/plain/internal/handlers/base.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,15 @@
logger = logging.getLogger("plain.request")


# These middleware classes are always used by Plain.
BUILTIN_MIDDLEWARE = [
"plain.internal.middleware.headers.DefaultHeadersMiddleware",
"plain.internal.middleware.https.HttpsRedirectMiddleware",
"plain.internal.middleware.slash.RedirectSlashMiddleware",
"plain.csrf.middleware.CsrfViewMiddleware",
]


class BaseHandler:
_view_middleware = None
_middleware_chain = None
@@ -27,7 +36,10 @@ def load_middleware(self):

get_response = self._get_response
handler = convert_exception_to_response(get_response)
for middleware_path in reversed(settings.MIDDLEWARE):

middlewares = reversed(BUILTIN_MIDDLEWARE + settings.MIDDLEWARE)

for middleware_path in middlewares:
middleware = import_string(middleware_path)
mw_instance = middleware(handler)

File renamed without changes.
19 changes: 19 additions & 0 deletions plain/plain/internal/middleware/headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from plain.runtime import settings


class DefaultHeadersMiddleware:
def __init__(self, get_response):
self.get_response = get_response

def __call__(self, request):
response = self.get_response(request)

for header, value in settings.DEFAULT_RESPONSE_HEADERS.items():
response.headers.setdefault(header, value)

# Add the Content-Length header to non-streaming responses if not
# already set.
if not response.streaming and not response.has_header("Content-Length"):
response.headers["Content-Length"] = str(len(response.content))

return response
36 changes: 36 additions & 0 deletions plain/plain/internal/middleware/https.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import re

from plain.http import ResponsePermanentRedirect
from plain.runtime import settings


class HttpsRedirectMiddleware:
def __init__(self, get_response):
self.get_response = get_response

# Settings for https (compile regexes once)
self.https_redirect_enabled = settings.HTTPS_REDIRECT_ENABLED
self.https_redirect_host = settings.HTTPS_REDIRECT_HOST
self.https_redirect_exempt = [
re.compile(r) for r in settings.HTTPS_REDIRECT_EXEMPT
]

def __call__(self, request):
"""
Rewrite the URL based on settings.APPEND_SLASH
"""

if redirect_response := self.maybe_https_redirect(request):
return redirect_response

return self.get_response(request)

def maybe_https_redirect(self, request):
path = request.path.lstrip("/")
if (
self.https_redirect_enabled
and not request.is_https()
and not any(pattern.search(path) for pattern in self.https_redirect_exempt)
):
host = self.https_redirect_host or request.get_host()
return ResponsePermanentRedirect(f"https://{host}{request.get_full_path()}")
Original file line number Diff line number Diff line change
@@ -1,87 +1,31 @@
import re

from plain.http import ResponsePermanentRedirect
from plain.runtime import settings
from plain.urls import is_valid_path
from plain.utils.http import escape_leading_slashes


class CommonMiddleware:
"""
"Common" middleware for taking care of some basic operations:
- Redirecting to HTTPS: Based on the HTTPS_REDIRECT_ENABLED setting,
redirect to HTTPS if the request is not secure.
- Default response headers: Add default headers to responses.
- URL rewriting: Based on the APPEND_SLASH setting,
append missing slashes.
- If APPEND_SLASH is set and the initial URL doesn't end with a
slash, and it is not found in urlpatterns, form a new URL by
appending a slash at the end. If this new URL is found in
urlpatterns, return an HTTP redirect to this new URL; otherwise
process the initial URL as usual.
This behavior can be customized by subclassing CommonMiddleware and
overriding the response_redirect_class attribute.
"""

response_redirect_class = ResponsePermanentRedirect

class RedirectSlashMiddleware:
def __init__(self, get_response):
self.get_response = get_response

# Settings for https (compile regexes once)
self.https_redirect_enabled = settings.HTTPS_REDIRECT_ENABLED
self.https_redirect_host = settings.HTTPS_REDIRECT_HOST
self.https_redirect_exempt = [
re.compile(r) for r in settings.HTTPS_REDIRECT_EXEMPT
]

def __call__(self, request):
"""
Rewrite the URL based on settings.APPEND_SLASH
"""

if redirect_response := self.maybe_https_redirect(request):
return redirect_response

response = self.get_response(request)

self.set_default_headers(response)

"""
When the status code of the response is 404, it may redirect to a path
with an appended slash if should_redirect_with_slash() returns True.
"""
# If the given URL is "Not Found", then check if we should redirect to
# a path with a slash appended.
if response.status_code == 404 and self.should_redirect_with_slash(request):
return self.response_redirect_class(self.get_full_path_with_slash(request))

# Add the Content-Length header to non-streaming responses if not
# already set.
if not response.streaming and not response.has_header("Content-Length"):
response.headers["Content-Length"] = str(len(response.content))
return ResponsePermanentRedirect(self.get_full_path_with_slash(request))

return response

def maybe_https_redirect(self, request):
path = request.path.lstrip("/")
if (
self.https_redirect_enabled
and not request.is_https()
and not any(pattern.search(path) for pattern in self.https_redirect_exempt)
):
host = self.https_redirect_host or request.get_host()
return ResponsePermanentRedirect(f"https://{host}{request.get_full_path()}")

def set_default_headers(self, response):
for header, value in settings.DEFAULT_RESPONSE_HEADERS.items():
response.headers.setdefault(header, value)

def should_redirect_with_slash(self, request):
"""
Return True if settings.APPEND_SLASH is True and appending a slash to
3 changes: 0 additions & 3 deletions plain/plain/middleware/README.md

This file was deleted.

Loading

0 comments on commit aab6311

Please sign in to comment.