diff --git a/csp_gateway/server/config/gateway/omnibus.yaml b/csp_gateway/server/config/gateway/omnibus.yaml index 1765988..4d88a25 100644 --- a/csp_gateway/server/config/gateway/omnibus.yaml +++ b/csp_gateway/server/config/gateway/omnibus.yaml @@ -30,6 +30,21 @@ modules: _target_: csp_gateway.MountWebSocketRoutes mount_api_key_middleware: _target_: csp_gateway.MountAPIKeyMiddleware + api_key: 12345 + enforce_ui: false + enforce_controls: false + mount_api_key_middleware_ui: + _target_: csp_gateway.MountAPIKeyMiddleware + api_key: token + enforce: [] + enforce_ui: true + enforce_controls: false + mount_api_key_middleware_controls: + _target_: csp_gateway.MountAPIKeyMiddleware + api_key: 54321 + enforce: [] + enforce_ui: false + enforce_controls: true gateway: _target_: csp_gateway.Gateway @@ -37,7 +52,6 @@ gateway: PORT: ${port} AUTHENTICATE: ${authenticate} UI: True - API_KEY: "12345" modules: - /modules/example_module - /modules/example_module_feedback @@ -49,6 +63,8 @@ gateway: - /modules/mount_rest_routes - /modules/mount_websocket_routes - /modules/mount_api_key_middleware + - /modules/mount_api_key_middleware_ui + - /modules/mount_api_key_middleware_controls channels: _target_: csp_gateway.server.demo.ExampleGatewayChannels diff --git a/csp_gateway/server/demo/config/omnibus.yaml b/csp_gateway/server/demo/config/omnibus.yaml index fea0991..d229d7d 100644 --- a/csp_gateway/server/demo/config/omnibus.yaml +++ b/csp_gateway/server/demo/config/omnibus.yaml @@ -3,7 +3,7 @@ defaults: - /gateway: omnibus - _self_ -# csp-gateway-start --config-dir=csp_gateway/server/omnibus +config=omnibus +# csp-gateway-start --config-dir=csp_gateway/server/demo +config=omnibus -authenticate: False +authenticate: True port: 8000 diff --git a/csp_gateway/server/demo/omnibus.py b/csp_gateway/server/demo/omnibus.py index 351fa50..a66c381 100644 --- a/csp_gateway/server/demo/omnibus.py +++ b/csp_gateway/server/demo/omnibus.py @@ -279,7 +279,7 @@ def push_to_perspective( # type: ignore[no-untyped-def] # be instantiated directly as we do so here: # Setting authentication - settings = GatewaySettings(API_KEY="12345", AUTHENTICATE=False) + settings = GatewaySettings(AUTHENTICATE=False) # instantiate gateway gateway = Gateway( diff --git a/csp_gateway/server/gateway/gateway.py b/csp_gateway/server/gateway/gateway.py index 5bd1f58..aa0bf9d 100644 --- a/csp_gateway/server/gateway/gateway.py +++ b/csp_gateway/server/gateway/gateway.py @@ -275,12 +275,26 @@ def start( log.info("Launching web server on:") url = f"http://{gethostname()}:{self.settings.PORT}" - if ui: - if self.settings.AUTHENTICATE: - log.info(f"\tUI: {url}?token={self.settings.API_KEY}") + if ui and self.settings.AUTHENTICATE: + from ..middleware import MountAPIKeyMiddleware + + # TODO: Will need to handle others + auth = "" + + # Find any middleware enforcing auth + for module in self.modules: + if isinstance(module, MountAPIKeyMiddleware) and module.enforce_ui is True: + auth = module.api_key + break + + if auth: + log.info(f"\tUI: {url}?{module.api_key_name}={auth}") else: log.info(f"\tUI: {url}") + else: + log.info(f"\tUI: {url}") + log.info(f"\tDocs: {url}/docs") log.info(f"\tDocs: {url}/redoc") diff --git a/csp_gateway/server/middleware/__init__.py b/csp_gateway/server/middleware/__init__.py index 96df696..1608a24 100644 --- a/csp_gateway/server/middleware/__init__.py +++ b/csp_gateway/server/middleware/__init__.py @@ -1 +1 @@ -from .api_key import MountAPIKeyMiddleware +from .api_key import * diff --git a/csp_gateway/server/middleware/api_key.py b/csp_gateway/server/middleware/api_key.py index 3c64dad..2d93aa4 100644 --- a/csp_gateway/server/middleware/api_key.py +++ b/csp_gateway/server/middleware/api_key.py @@ -1,12 +1,17 @@ from datetime import timedelta +from logging import getLogger +from secrets import token_urlsafe +from typing import List from fastapi import APIRouter, Depends, HTTPException, Request, Security from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, field_validator from starlette.status import HTTP_403_FORBIDDEN from csp_gateway.server import GatewayChannels, GatewayModule +from ..shared import ChannelSelection + # separate to avoid circular from ..web import GatewayWebApp from .hacks.api_key_middleware_websocket_fix.api_key import ( @@ -15,54 +20,154 @@ APIKeyQuery, ) +_log = getLogger(__name__) -class MountAPIKeyMiddleware(GatewayModule): - api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12)) +__all__ = ( + "MountAuthMiddleware", + "MountAPIKeyMiddleware", +) + +# TODO: More eventually + + +class MountAuthMiddleware(GatewayModule): + enforce: list = Field(default=(), description="Routes to enforce, default empty means 'all'") + channels: ChannelSelection = Field( + default_factory=ChannelSelection, + description="Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'", + ) - # NOTE: don't make this publically configureable - # as it is needed in gateway.py - _api_key_name: str = PrivateAttr("token") - _api_key_secret: str = PrivateAttr("") + enforce_controls: bool = Field(default=False, description="Whether to allow access to controls routes. Defaults to True") + enforce_ui: bool = Field(default=True, description="Whether to allow web access to the API Key authentication routes. Defaults to True") unauthorized_status_message: str = "unauthorized" + _enforced_channels: List[str] = PrivateAttr(default_factory=list) + def connect(self, channels: GatewayChannels) -> None: # NO-OP ... + +class MountAPIKeyMiddleware(MountAuthMiddleware): + api_key: str = Field(default=token_urlsafe(32), description="API Key to use") + api_key_name: str = Field(default="token", description="API Key to use") + api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12)) + + _instance_count = 0 + + @field_validator("api_key_name", mode="before") + @classmethod + def _validate_api_key_name(cls, value: str) -> str: + if not value: + raise ValueError("API Key name must be a non-empty string") + value = f"{value.strip().lower()}-{cls._instance_count}" + cls._instance_count += 1 + return value + def rest(self, app: GatewayWebApp) -> None: if app.settings.AUTHENTICATE: - # first, pull out the api key secret from the settings - self._api_key_secret = app.settings.API_KEY - - # reinitialize header - api_key_query = APIKeyQuery(name=self._api_key_name, auto_error=False) - api_key_header = APIKeyHeader(name=self._api_key_name, auto_error=False) - api_key_cookie = APIKeyCookie(name=self._api_key_name, auto_error=False) - - # routers - auth_router: APIRouter = app.get_router("auth") - public_router: APIRouter = app.get_router("public") - - # now mount middleware - async def get_api_key( - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - api_key_cookie: str = Security(api_key_cookie), - ): - if api_key_query == self._api_key_secret or api_key_header == self._api_key_secret or api_key_cookie == self._api_key_secret: - return self._api_key_secret - else: + # Use configuration to determine allowed routes + # for this API key + self._calculate_auth(app) + + # Setup the routes for authentication + self._setup_routes(app) + + def _calculate_auth(self, app: GatewayWebApp) -> None: + self._enforced_channels = self.channels.select_from(app.gateway.channels_model) + + # Fully form the url + self._api_str = app.settings.API_STR + + def _setup_routes(self, app: GatewayWebApp) -> None: + # reinitialize header + api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False) + api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False) + api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False) + + # routers + auth_router: APIRouter = app.get_router("auth") + public_router: APIRouter = app.get_router("public") + + # now mount middleware + async def get_api_key( + request: Request = None, + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + api_key_cookie: str = Security(api_key_cookie), + ): + if request is None: + # If request is None, we are not in a request context, return None + _log.warning("API Key check: request is None, returning None") + return None + + if hasattr(request.state, "auth"): + # Already authenticated, return the API key + _log.info(f"API Key check: already authenticated, returning {self.api_key_name}") + return request.state.auth + + resolved_path = request.url.path.rstrip("/").replace(self._api_str, "").lstrip("/").rsplit("/", 1) + + if len(resolved_path) == 1: + root = resolved_path[0] + channel = "" + + elif len(resolved_path) > 1: + root = resolved_path[0] + channel = resolved_path[1] + + if self.enforce and root not in self.enforce: + # Route not in enforce, allow + _log.info(f"API Key check: {root}/{channel} not in enforced list {self.enforce}, allowing") + return "" + + if root == "controls" and not self.enforce_controls: + # Controls route not enforced, allow + _log.info(f"API Key check: root {root} not enforced, allowing") + return "" + + # TODO + if root in ("", "auth", "perspective") and not self.enforce_ui: + # UI route not enforced, allow + _log.info(f"API Key check: root {root} not enforced, allowing") + return "" + + if root not in ("controls", "auth", "perspective") and channel and channel not in self._enforced_channels: + # Channel not in enforce, allow + _log.info(f"API Key check: channel {root}/{channel} not in enforced channels {self._enforced_channels}, allowing") + return "" + + # Else, enforce + if api_key_query == self.api_key or api_key_header == self.api_key or api_key_cookie == self.api_key: + # Return the API key secret to allow access + _log.info(f"API Key check: {self.api_key_name} matched for {root}/{channel}, allowing access") + + # NOTE: only set this if we are the one validating, not if we are ignoring + request.state.auth = self.api_key + return self.api_key + + _log.warning(f"API Key check: {self.api_key_name} did not match, denying access") + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail=self.unauthorized_status_message, + ) + + # add auth to all other routes + app.add_middleware(Depends(get_api_key)) + + if self.enforce_ui: + + @auth_router.get("/login") + async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)): + if not api_key: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail=self.unauthorized_status_message, ) - - @auth_router.get("/login") - async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)): response = RedirectResponse(url="/") response.set_cookie( - self._api_key_name, + self.api_key_name, value=api_key, domain=app.settings.AUTHENTICATION_DOMAIN, httponly=True, @@ -74,44 +179,40 @@ async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)): @auth_router.get("/logout") async def route_logout_and_remove_cookie(): response = RedirectResponse(url="/login") - response.delete_cookie(self._api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN) + response.delete_cookie(self.api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN) return response # I'm hand rolling these for now... @public_router.get("/login", response_class=HTMLResponse, include_in_schema=False) async def get_login_page(token: str = "", request: Request = None): - if token: - if token != "": - return RedirectResponse(url=f"{app.settings.API_V1_STR}/auth/login?token={token}") + if token and token != "": + return RedirectResponse(url=f"{self._api_str}/auth/login?token={token}") return app.templates.TemplateResponse( "login.html.j2", - {"request": request, "api_key_name": self._api_key_name}, + {"request": request, "api_key_name": self.api_key_name}, ) @public_router.get("/logout", response_class=HTMLResponse, include_in_schema=False) async def get_logout_page(request: Request = None): return app.templates.TemplateResponse("logout.html.j2", {"request": request}) - # add auth to all other routes - app.add_middleware(Depends(get_api_key)) - - @app.app.exception_handler(403) - async def custom_403_handler(request: Request = None, *args): - if "/api" in request.url.path: - # programmatic api access, return json - return JSONResponse( - { - "detail": self.unauthorized_status_message, - "status_code": 403, - }, - status_code=403, - ) - return app.templates.TemplateResponse( - "login.html.j2", + @app.app.exception_handler(403) + async def custom_403_handler(request: Request = None, *args): + if "/api" in request.url.path: + # programmatic api access, return json + return JSONResponse( { - "request": request, - "api_key_name": self._api_key_name, - "status_code": 403, "detail": self.unauthorized_status_message, + "status_code": 403, }, + status_code=403, ) + return app.templates.TemplateResponse( + "login.html.j2", + { + "request": request, + "api_key_name": self.api_key_name, + "status_code": 403, + "detail": self.unauthorized_status_message, + }, + ) diff --git a/csp_gateway/server/settings.py b/csp_gateway/server/settings.py index 1c5aff5..3266d9d 100644 --- a/csp_gateway/server/settings.py +++ b/csp_gateway/server/settings.py @@ -1,4 +1,3 @@ -from secrets import token_urlsafe from socket import gethostname from typing import List @@ -31,8 +30,4 @@ class Settings(BaseSettings): UI: bool = Field(False, description="Enables ui in the web application") AUTHENTICATE: bool = Field(False, description="Whether to authenticate users for access to the web application") - API_KEY: str = Field( - token_urlsafe(32), - description="The API key for access if `AUTHENTICATE=True`. The default is auto-generated, but a user-provided value can be used.", - ) AUTHENTICATION_DOMAIN: str = gethostname()