Skip to content

Commit a29ebd2

Browse files
committed
[WIP] Allow for multiple api key middlewares
Signed-off-by: Tim Paine <[email protected]>
1 parent 8fb8f68 commit a29ebd2

File tree

7 files changed

+192
-66
lines changed

7 files changed

+192
-66
lines changed

csp_gateway/server/config/gateway/omnibus.yaml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,28 @@ modules:
3030
_target_: csp_gateway.MountWebSocketRoutes
3131
mount_api_key_middleware:
3232
_target_: csp_gateway.MountAPIKeyMiddleware
33+
api_key: 12345
34+
enforce_ui: false
35+
enforce_controls: false
36+
mount_api_key_middleware_ui:
37+
_target_: csp_gateway.MountAPIKeyMiddleware
38+
api_key: token
39+
enforce: []
40+
enforce_ui: true
41+
enforce_controls: false
42+
mount_api_key_middleware_controls:
43+
_target_: csp_gateway.MountAPIKeyMiddleware
44+
api_key: 54321
45+
enforce: []
46+
enforce_ui: false
47+
enforce_controls: true
3348

3449
gateway:
3550
_target_: csp_gateway.Gateway
3651
settings:
3752
PORT: ${port}
3853
AUTHENTICATE: ${authenticate}
3954
UI: True
40-
API_KEY: "12345"
4155
modules:
4256
- /modules/example_module
4357
- /modules/example_module_feedback
@@ -49,6 +63,8 @@ gateway:
4963
- /modules/mount_rest_routes
5064
- /modules/mount_websocket_routes
5165
- /modules/mount_api_key_middleware
66+
- /modules/mount_api_key_middleware_ui
67+
- /modules/mount_api_key_middleware_controls
5268
channels:
5369
_target_: csp_gateway.server.demo.ExampleGatewayChannels
5470

csp_gateway/server/demo/config/omnibus.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defaults:
33
- /gateway: omnibus
44
- _self_
55

6-
# csp-gateway-start --config-dir=csp_gateway/server/omnibus +config=omnibus
6+
# csp-gateway-start --config-dir=csp_gateway/server/demo +config=omnibus
77

8-
authenticate: False
8+
authenticate: True
99
port: 8000

csp_gateway/server/demo/omnibus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def push_to_perspective( # type: ignore[no-untyped-def]
279279
# be instantiated directly as we do so here:
280280

281281
# Setting authentication
282-
settings = GatewaySettings(API_KEY="12345", AUTHENTICATE=False)
282+
settings = GatewaySettings(AUTHENTICATE=False)
283283

284284
# instantiate gateway
285285
gateway = Gateway(

csp_gateway/server/gateway/gateway.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,26 @@ def start(
275275
log.info("Launching web server on:")
276276
url = f"http://{gethostname()}:{self.settings.PORT}"
277277

278-
if ui:
279-
if self.settings.AUTHENTICATE:
280-
log.info(f"\tUI: {url}?token={self.settings.API_KEY}")
278+
if ui and self.settings.AUTHENTICATE:
279+
from ..middleware import MountAPIKeyMiddleware
280+
281+
# TODO: Will need to handle others
282+
auth = ""
283+
284+
# Find any middleware enforcing auth
285+
for module in self.modules:
286+
if isinstance(module, MountAPIKeyMiddleware) and module.enforce_ui == True:
287+
auth = module.api_key
288+
break
289+
290+
if auth:
291+
log.info(f"\tUI: {url}?{module.api_key_name}={auth}")
281292
else:
282293
log.info(f"\tUI: {url}")
283294

295+
else:
296+
log.info(f"\tUI: {url}")
297+
284298
log.info(f"\tDocs: {url}/docs")
285299
log.info(f"\tDocs: {url}/redoc")
286300

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .api_key import MountAPIKeyMiddleware
1+
from .api_key import *

csp_gateway/server/middleware/api_key.py

Lines changed: 154 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from datetime import timedelta
2+
from logging import getLogger
3+
from secrets import token_urlsafe
4+
from typing import List
25

36
from fastapi import APIRouter, Depends, HTTPException, Request, Security
47
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
5-
from pydantic import Field, PrivateAttr
8+
from pydantic import Field, PrivateAttr, field_validator
69
from starlette.status import HTTP_403_FORBIDDEN
710

811
from csp_gateway.server import GatewayChannels, GatewayModule
912

13+
from ..shared import ChannelSelection
14+
1015
# separate to avoid circular
1116
from ..web import GatewayWebApp
1217
from .hacks.api_key_middleware_websocket_fix.api_key import (
@@ -15,86 +20,182 @@
1520
APIKeyQuery,
1621
)
1722

23+
_log = getLogger(__name__)
1824

19-
class MountAPIKeyMiddleware(GatewayModule):
20-
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))
25+
__all__ = (
26+
"MountAuthMiddleware",
27+
"MountAPIKeyMiddleware",
28+
)
2129

22-
# NOTE: don't make this publically configureable
23-
# as it is needed in gateway.py
24-
_api_key_name: str = PrivateAttr("token")
25-
_api_key_secret: str = PrivateAttr("")
30+
# TODO: More eventually
31+
32+
class MountAuthMiddleware(GatewayModule):
33+
enforce: list = Field(default=(), description="Routes to enforce, default empty means 'all'")
34+
channels: ChannelSelection = Field(
35+
default_factory=ChannelSelection,
36+
description="Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'",
37+
)
38+
39+
enforce_controls: bool = Field(default=False, description="Whether to allow access to controls routes. Defaults to True")
40+
enforce_ui: bool = Field(default=True, description="Whether to allow web access to the API Key authentication routes. Defaults to True")
2641

2742
unauthorized_status_message: str = "unauthorized"
2843

44+
_enforced_channels: List[str] = PrivateAttr(default_factory=list)
45+
2946
def connect(self, channels: GatewayChannels) -> None:
3047
# NO-OP
3148
...
3249

50+
51+
class MountAPIKeyMiddleware(MountAuthMiddleware):
52+
api_key: str = Field(default=token_urlsafe(32), description="API Key to use")
53+
api_key_name: str = Field(default="token", description="API Key to use")
54+
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))
55+
56+
_instance_count = 0
57+
58+
@field_validator("api_key_name", mode="before")
59+
@classmethod
60+
def _validate_api_key_name(cls, value: str) -> str:
61+
if not value:
62+
raise ValueError("API Key name must be a non-empty string")
63+
value = f"{value.strip().lower()}-{cls._instance_count}"
64+
cls._instance_count += 1
65+
return value
66+
3367
def rest(self, app: GatewayWebApp) -> None:
3468
if app.settings.AUTHENTICATE:
35-
# first, pull out the api key secret from the settings
36-
self._api_key_secret = app.settings.API_KEY
69+
# Use configuration to determine allowed routes
70+
# for this API key
71+
self._calculate_auth(app)
72+
73+
# Setup the routes for authentication
74+
self._setup_routes(app)
75+
76+
77+
def _calculate_auth(self, app: GatewayWebApp) -> None:
78+
self._enforced_channels = self.channels.select_from(app.gateway.channels_model)
79+
80+
# Fully form the url
81+
self._api_str = app.settings.API_STR
3782

83+
def _setup_routes(self, app: GatewayWebApp) -> None:
3884
# reinitialize header
39-
api_key_query = APIKeyQuery(name=self._api_key_name, auto_error=False)
40-
api_key_header = APIKeyHeader(name=self._api_key_name, auto_error=False)
41-
api_key_cookie = APIKeyCookie(name=self._api_key_name, auto_error=False)
85+
api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False)
86+
api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False)
87+
api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False)
4288

4389
# routers
4490
auth_router: APIRouter = app.get_router("auth")
4591
public_router: APIRouter = app.get_router("public")
4692

4793
# now mount middleware
4894
async def get_api_key(
95+
request: Request = None,
4996
api_key_query: str = Security(api_key_query),
5097
api_key_header: str = Security(api_key_header),
5198
api_key_cookie: str = Security(api_key_cookie),
5299
):
53-
if api_key_query == self._api_key_secret or api_key_header == self._api_key_secret or api_key_cookie == self._api_key_secret:
54-
return self._api_key_secret
55-
else:
56-
raise HTTPException(
57-
status_code=HTTP_403_FORBIDDEN,
58-
detail=self.unauthorized_status_message,
59-
)
60-
61-
@auth_router.get("/login")
62-
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
63-
response = RedirectResponse(url="/")
64-
response.set_cookie(
65-
self._api_key_name,
66-
value=api_key,
67-
domain=app.settings.AUTHENTICATION_DOMAIN,
68-
httponly=True,
69-
max_age=self.api_key_timeout.total_seconds(),
70-
expires=self.api_key_timeout.total_seconds(),
71-
)
72-
return response
73-
74-
@auth_router.get("/logout")
75-
async def route_logout_and_remove_cookie():
76-
response = RedirectResponse(url="/login")
77-
response.delete_cookie(self._api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
78-
return response
79-
80-
# I'm hand rolling these for now...
81-
@public_router.get("/login", response_class=HTMLResponse, include_in_schema=False)
82-
async def get_login_page(token: str = "", request: Request = None):
83-
if token:
84-
if token != "":
85-
return RedirectResponse(url=f"{app.settings.API_V1_STR}/auth/login?token={token}")
86-
return app.templates.TemplateResponse(
87-
"login.html.j2",
88-
{"request": request, "api_key_name": self._api_key_name},
100+
if request is None:
101+
# If request is None, we are not in a request context, return None
102+
_log.warning("API Key check: request is None, returning None")
103+
return None
104+
105+
if hasattr(request.state, "auth"):
106+
# Already authenticated, return the API key
107+
_log.info(f"API Key check: already authenticated, returning {self.api_key_name}")
108+
return request.state.auth
109+
110+
resolved_path = request.url.path.rstrip("/").replace(self._api_str, "").lstrip("/").rsplit("/", 1)
111+
112+
if len(resolved_path) == 1:
113+
root = resolved_path[0]
114+
channel = ""
115+
116+
elif len(resolved_path) > 1:
117+
root = resolved_path[0]
118+
channel = resolved_path[1]
119+
120+
if self.enforce and root not in self.enforce:
121+
# Route not in enforce, allow
122+
_log.info(f"API Key check: {root}/{channel} not in enforced list {self.enforce}, allowing")
123+
return ""
124+
125+
if root == "controls" and not self.enforce_controls:
126+
# Controls route not enforced, allow
127+
_log.info(f"API Key check: root {root} not enforced, allowing")
128+
return ""
129+
130+
# TODO
131+
if root in ("", "auth", "perspective") and not self.enforce_ui:
132+
# UI route not enforced, allow
133+
_log.info(f"API Key check: root {root} not enforced, allowing")
134+
return ""
135+
136+
if root not in ("controls", "auth", "perspective") and channel and channel not in self._enforced_channels:
137+
# Channel not in enforce, allow
138+
_log.info(f"API Key check: channel {root}/{channel} not in enforced channels {self._enforced_channels}, allowing")
139+
return ""
140+
141+
# Else, enforce
142+
if api_key_query == self.api_key or api_key_header == self.api_key or api_key_cookie == self.api_key:
143+
# Return the API key secret to allow access
144+
_log.info(f"API Key check: {self.api_key_name} matched for {root}/{channel}, allowing access")
145+
146+
# NOTE: only set this if we are the one validating, not if we are ignoring
147+
request.state.auth = self.api_key
148+
return self.api_key
149+
150+
_log.warning(f"API Key check: {self.api_key_name} did not match, denying access")
151+
raise HTTPException(
152+
status_code=HTTP_403_FORBIDDEN,
153+
detail=self.unauthorized_status_message,
89154
)
90155

91-
@public_router.get("/logout", response_class=HTMLResponse, include_in_schema=False)
92-
async def get_logout_page(request: Request = None):
93-
return app.templates.TemplateResponse("logout.html.j2", {"request": request})
94-
95156
# add auth to all other routes
96157
app.add_middleware(Depends(get_api_key))
97158

159+
if self.enforce_ui:
160+
@auth_router.get("/login")
161+
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
162+
if not api_key:
163+
raise HTTPException(
164+
status_code=HTTP_403_FORBIDDEN,
165+
detail=self.unauthorized_status_message,
166+
)
167+
response = RedirectResponse(url="/")
168+
response.set_cookie(
169+
self.api_key_name,
170+
value=api_key,
171+
domain=app.settings.AUTHENTICATION_DOMAIN,
172+
httponly=True,
173+
max_age=self.api_key_timeout.total_seconds(),
174+
expires=self.api_key_timeout.total_seconds(),
175+
)
176+
return response
177+
178+
@auth_router.get("/logout")
179+
async def route_logout_and_remove_cookie():
180+
response = RedirectResponse(url="/login")
181+
response.delete_cookie(self.api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
182+
return response
183+
184+
# I'm hand rolling these for now...
185+
@public_router.get("/login", response_class=HTMLResponse, include_in_schema=False)
186+
async def get_login_page(token: str = "", request: Request = None):
187+
if token and token != "":
188+
return RedirectResponse(url=f"{self._api_str}/auth/login?token={token}")
189+
return app.templates.TemplateResponse(
190+
"login.html.j2",
191+
{"request": request, "api_key_name": self.api_key_name},
192+
)
193+
194+
@public_router.get("/logout", response_class=HTMLResponse, include_in_schema=False)
195+
async def get_logout_page(request: Request = None):
196+
return app.templates.TemplateResponse("logout.html.j2", {"request": request})
197+
198+
98199
@app.app.exception_handler(403)
99200
async def custom_403_handler(request: Request = None, *args):
100201
if "/api" in request.url.path:
@@ -110,7 +211,7 @@ async def custom_403_handler(request: Request = None, *args):
110211
"login.html.j2",
111212
{
112213
"request": request,
113-
"api_key_name": self._api_key_name,
214+
"api_key_name": self.api_key_name,
114215
"status_code": 403,
115216
"detail": self.unauthorized_status_message,
116217
},

csp_gateway/server/settings.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from secrets import token_urlsafe
21
from socket import gethostname
32
from typing import List
43

@@ -31,8 +30,4 @@ class Settings(BaseSettings):
3130

3231
UI: bool = Field(False, description="Enables ui in the web application")
3332
AUTHENTICATE: bool = Field(False, description="Whether to authenticate users for access to the web application")
34-
API_KEY: str = Field(
35-
token_urlsafe(32),
36-
description="The API key for access if `AUTHENTICATE=True`. The default is auto-generated, but a user-provided value can be used.",
37-
)
3833
AUTHENTICATION_DOMAIN: str = gethostname()

0 commit comments

Comments
 (0)