1
1
from datetime import timedelta
2
+ from logging import getLogger
3
+ from secrets import token_urlsafe
4
+ from typing import List
2
5
3
6
from fastapi import APIRouter , Depends , HTTPException , Request , Security
4
7
from fastapi .responses import HTMLResponse , JSONResponse , RedirectResponse
5
- from pydantic import Field , PrivateAttr
8
+ from pydantic import Field , PrivateAttr , field_validator
6
9
from starlette .status import HTTP_403_FORBIDDEN
7
10
8
11
from csp_gateway .server import GatewayChannels , GatewayModule
9
12
13
+ from ..shared import ChannelSelection
14
+
10
15
# separate to avoid circular
11
16
from ..web import GatewayWebApp
12
17
from .hacks .api_key_middleware_websocket_fix .api_key import (
15
20
APIKeyQuery ,
16
21
)
17
22
23
+ _log = getLogger (__name__ )
18
24
19
- class MountAPIKeyMiddleware (GatewayModule ):
20
- api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
25
+ __all__ = (
26
+ "MountAuthMiddleware" ,
27
+ "MountAPIKeyMiddleware" ,
28
+ )
21
29
22
- # NOTE: don't make this publically configureable
23
- # as it is needed in gateway.py
24
- _api_key_name : str = PrivateAttr ("token" )
25
- _api_key_secret : str = PrivateAttr ("" )
30
+ # TODO: More eventually
31
+
32
+ class MountAuthMiddleware (GatewayModule ):
33
+ enforce : list = Field (default = (), description = "Routes to enforce, default empty means 'all'" )
34
+ channels : ChannelSelection = Field (
35
+ default_factory = ChannelSelection ,
36
+ description = "Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'" ,
37
+ )
38
+
39
+ enforce_controls : bool = Field (default = False , description = "Whether to allow access to controls routes. Defaults to True" )
40
+ enforce_ui : bool = Field (default = True , description = "Whether to allow web access to the API Key authentication routes. Defaults to True" )
26
41
27
42
unauthorized_status_message : str = "unauthorized"
28
43
44
+ _enforced_channels : List [str ] = PrivateAttr (default_factory = list )
45
+
29
46
def connect (self , channels : GatewayChannels ) -> None :
30
47
# NO-OP
31
48
...
32
49
50
+
51
+ class MountAPIKeyMiddleware (MountAuthMiddleware ):
52
+ api_key : str = Field (default = token_urlsafe (32 ), description = "API Key to use" )
53
+ api_key_name : str = Field (default = "token" , description = "API Key to use" )
54
+ api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
55
+
56
+ _instance_count = 0
57
+
58
+ @field_validator ("api_key_name" , mode = "before" )
59
+ @classmethod
60
+ def _validate_api_key_name (cls , value : str ) -> str :
61
+ if not value :
62
+ raise ValueError ("API Key name must be a non-empty string" )
63
+ value = f"{ value .strip ().lower ()} -{ cls ._instance_count } "
64
+ cls ._instance_count += 1
65
+ return value
66
+
33
67
def rest (self , app : GatewayWebApp ) -> None :
34
68
if app .settings .AUTHENTICATE :
35
- # first, pull out the api key secret from the settings
36
- self ._api_key_secret = app .settings .API_KEY
69
+ # Use configuration to determine allowed routes
70
+ # for this API key
71
+ self ._calculate_auth (app )
72
+
73
+ # Setup the routes for authentication
74
+ self ._setup_routes (app )
75
+
76
+
77
+ def _calculate_auth (self , app : GatewayWebApp ) -> None :
78
+ self ._enforced_channels = self .channels .select_from (app .gateway .channels_model )
79
+
80
+ # Fully form the url
81
+ self ._api_str = app .settings .API_STR
37
82
83
+ def _setup_routes (self , app : GatewayWebApp ) -> None :
38
84
# reinitialize header
39
- api_key_query = APIKeyQuery (name = self ._api_key_name , auto_error = False )
40
- api_key_header = APIKeyHeader (name = self ._api_key_name , auto_error = False )
41
- api_key_cookie = APIKeyCookie (name = self ._api_key_name , auto_error = False )
85
+ api_key_query = APIKeyQuery (name = self .api_key_name , auto_error = False )
86
+ api_key_header = APIKeyHeader (name = self .api_key_name , auto_error = False )
87
+ api_key_cookie = APIKeyCookie (name = self .api_key_name , auto_error = False )
42
88
43
89
# routers
44
90
auth_router : APIRouter = app .get_router ("auth" )
45
91
public_router : APIRouter = app .get_router ("public" )
46
92
47
93
# now mount middleware
48
94
async def get_api_key (
95
+ request : Request = None ,
49
96
api_key_query : str = Security (api_key_query ),
50
97
api_key_header : str = Security (api_key_header ),
51
98
api_key_cookie : str = Security (api_key_cookie ),
52
99
):
53
- if api_key_query == self ._api_key_secret or api_key_header == self ._api_key_secret or api_key_cookie == self ._api_key_secret :
54
- return self ._api_key_secret
55
- else :
56
- raise HTTPException (
57
- status_code = HTTP_403_FORBIDDEN ,
58
- detail = self .unauthorized_status_message ,
59
- )
60
-
61
- @auth_router .get ("/login" )
62
- async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
63
- response = RedirectResponse (url = "/" )
64
- response .set_cookie (
65
- self ._api_key_name ,
66
- value = api_key ,
67
- domain = app .settings .AUTHENTICATION_DOMAIN ,
68
- httponly = True ,
69
- max_age = self .api_key_timeout .total_seconds (),
70
- expires = self .api_key_timeout .total_seconds (),
71
- )
72
- return response
73
-
74
- @auth_router .get ("/logout" )
75
- async def route_logout_and_remove_cookie ():
76
- response = RedirectResponse (url = "/login" )
77
- response .delete_cookie (self ._api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
78
- return response
79
-
80
- # I'm hand rolling these for now...
81
- @public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
82
- async def get_login_page (token : str = "" , request : Request = None ):
83
- if token :
84
- if token != "" :
85
- return RedirectResponse (url = f"{ app .settings .API_V1_STR } /auth/login?token={ token } " )
86
- return app .templates .TemplateResponse (
87
- "login.html.j2" ,
88
- {"request" : request , "api_key_name" : self ._api_key_name },
100
+ if request is None :
101
+ # If request is None, we are not in a request context, return None
102
+ _log .warning ("API Key check: request is None, returning None" )
103
+ return None
104
+
105
+ if hasattr (request .state , "auth" ):
106
+ # Already authenticated, return the API key
107
+ _log .info (f"API Key check: already authenticated, returning { self .api_key_name } " )
108
+ return request .state .auth
109
+
110
+ resolved_path = request .url .path .rstrip ("/" ).replace (self ._api_str , "" ).lstrip ("/" ).rsplit ("/" , 1 )
111
+
112
+ if len (resolved_path ) == 1 :
113
+ root = resolved_path [0 ]
114
+ channel = ""
115
+
116
+ elif len (resolved_path ) > 1 :
117
+ root = resolved_path [0 ]
118
+ channel = resolved_path [1 ]
119
+
120
+ if self .enforce and root not in self .enforce :
121
+ # Route not in enforce, allow
122
+ _log .info (f"API Key check: { root } /{ channel } not in enforced list { self .enforce } , allowing" )
123
+ return ""
124
+
125
+ if root == "controls" and not self .enforce_controls :
126
+ # Controls route not enforced, allow
127
+ _log .info (f"API Key check: root { root } not enforced, allowing" )
128
+ return ""
129
+
130
+ # TODO
131
+ if root in ("" , "auth" , "perspective" ) and not self .enforce_ui :
132
+ # UI route not enforced, allow
133
+ _log .info (f"API Key check: root { root } not enforced, allowing" )
134
+ return ""
135
+
136
+ if root not in ("controls" , "auth" , "perspective" ) and channel and channel not in self ._enforced_channels :
137
+ # Channel not in enforce, allow
138
+ _log .info (f"API Key check: channel { root } /{ channel } not in enforced channels { self ._enforced_channels } , allowing" )
139
+ return ""
140
+
141
+ # Else, enforce
142
+ if api_key_query == self .api_key or api_key_header == self .api_key or api_key_cookie == self .api_key :
143
+ # Return the API key secret to allow access
144
+ _log .info (f"API Key check: { self .api_key_name } matched for { root } /{ channel } , allowing access" )
145
+
146
+ # NOTE: only set this if we are the one validating, not if we are ignoring
147
+ request .state .auth = self .api_key
148
+ return self .api_key
149
+
150
+ _log .warning (f"API Key check: { self .api_key_name } did not match, denying access" )
151
+ raise HTTPException (
152
+ status_code = HTTP_403_FORBIDDEN ,
153
+ detail = self .unauthorized_status_message ,
89
154
)
90
155
91
- @public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
92
- async def get_logout_page (request : Request = None ):
93
- return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
94
-
95
156
# add auth to all other routes
96
157
app .add_middleware (Depends (get_api_key ))
97
158
159
+ if self .enforce_ui :
160
+ @auth_router .get ("/login" )
161
+ async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
162
+ if not api_key :
163
+ raise HTTPException (
164
+ status_code = HTTP_403_FORBIDDEN ,
165
+ detail = self .unauthorized_status_message ,
166
+ )
167
+ response = RedirectResponse (url = "/" )
168
+ response .set_cookie (
169
+ self .api_key_name ,
170
+ value = api_key ,
171
+ domain = app .settings .AUTHENTICATION_DOMAIN ,
172
+ httponly = True ,
173
+ max_age = self .api_key_timeout .total_seconds (),
174
+ expires = self .api_key_timeout .total_seconds (),
175
+ )
176
+ return response
177
+
178
+ @auth_router .get ("/logout" )
179
+ async def route_logout_and_remove_cookie ():
180
+ response = RedirectResponse (url = "/login" )
181
+ response .delete_cookie (self .api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
182
+ return response
183
+
184
+ # I'm hand rolling these for now...
185
+ @public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
186
+ async def get_login_page (token : str = "" , request : Request = None ):
187
+ if token and token != "" :
188
+ return RedirectResponse (url = f"{ self ._api_str } /auth/login?token={ token } " )
189
+ return app .templates .TemplateResponse (
190
+ "login.html.j2" ,
191
+ {"request" : request , "api_key_name" : self .api_key_name },
192
+ )
193
+
194
+ @public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
195
+ async def get_logout_page (request : Request = None ):
196
+ return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
197
+
198
+
98
199
@app .app .exception_handler (403 )
99
200
async def custom_403_handler (request : Request = None , * args ):
100
201
if "/api" in request .url .path :
@@ -110,7 +211,7 @@ async def custom_403_handler(request: Request = None, *args):
110
211
"login.html.j2" ,
111
212
{
112
213
"request" : request ,
113
- "api_key_name" : self ._api_key_name ,
214
+ "api_key_name" : self .api_key_name ,
114
215
"status_code" : 403 ,
115
216
"detail" : self .unauthorized_status_message ,
116
217
},
0 commit comments