diff --git a/.github/workflows/eslint.yml b/.github/workflows/eslint.yml index 816be2d..760cf93 100644 --- a/.github/workflows/eslint.yml +++ b/.github/workflows/eslint.yml @@ -12,7 +12,9 @@ jobs: name: ESLint Check steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 + with: + submodules: true - name: Install dependencies working-directory: ./react_frontend run: yarn diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 09099e1..817dd94 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -12,11 +12,13 @@ jobs: name: Get newest code and run pylint steps: - name: Checkout repository - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/checkout@v4 with: - python-version: 3.9 + submodules: true + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.11.4 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -24,4 +26,4 @@ jobs: pip install pylint - name: Analysing the code with pylint run: | - python3 -m pylint `find . -type f | grep .py$ | grep -v tests/ | xargs` --disable=R0201,R0903 + python3 -m pylint `find . -type f | grep .py$ | grep -v -e tests/ -e core_lib/ -e venv/ | xargs` --disable=R0903 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 72df0c9..2ddaeb0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,10 @@ jobs: name: Check for new commits runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: true - name: Get the newest hash and latest release hash id: tag_and_commit_hash_chech @@ -93,8 +96,10 @@ jobs: with: node-version: '14' - - uses: actions/checkout@v2 + - name: Checkout repository + uses: actions/checkout@v4 with: + submodules: true path: ValDB - name: Save timestamp in release diff --git a/.github/workflows/tslint.yml b/.github/workflows/tslint.yml index 9508372..7f6c278 100644 --- a/.github/workflows/tslint.yml +++ b/.github/workflows/tslint.yml @@ -12,7 +12,9 @@ jobs: name: TSLint Check steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 + with: + submodules: true - name: Install dependencies working-directory: ./react_frontend run: yarn diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e97fd55 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "core_lib"] + path = core_lib + url = https://github.com/cms-PdmV/PdmVWebCore.git diff --git a/api/user.py b/api/user.py index 90b6872..135d794 100644 --- a/api/user.py +++ b/api/user.py @@ -11,6 +11,7 @@ from utils.logger import LoggerManager from core import Namespace from core.database import get_database +from core_lib.middlewares.auth import UserInfo as MiddlewareUserInfo from data.group import get_all_groups from models.user import User, UserRole from lookup.user_group import UserGroupLookup @@ -81,8 +82,9 @@ def get(self): ''' Get current user info from request ''' - email = session.get('user').get('email') - fullname = session.get('user').get('fullname') + user_data: MiddlewareUserInfo = session.get("user") + email = user_data.email + fullname = user_data.fullname _logger.info('Checking if user is already registered') existed_user = User.get_by_email(email=email) if not existed_user: diff --git a/core_lib b/core_lib new file mode 160000 index 0000000..8a31dea --- /dev/null +++ b/core_lib @@ -0,0 +1 @@ +Subproject commit 8a31dea389e5bac3c2ab875c425a5ac2a57f291a diff --git a/main.py b/main.py index 42c586f..80f45c8 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,6 @@ Main module for valdb ''' import os -from werkzeug.middleware.proxy_fix import ProxyFix from flask import Flask, request, session, render_template from jinja2.exceptions import TemplateNotFound from flask_cors import CORS @@ -12,7 +11,7 @@ from api.static import serve_file from database.index import database_index_setup from lookup.user_group import UserGroupLookup -from middlewares.auth import AuthenticationMiddleware +from core_lib.middlewares.auth import AuthenticationMiddleware load_dotenv() @@ -30,9 +29,6 @@ # Set secret key for session cookie app.secret_key = os.getenv('SECRET_KEY') -# Handle redirections from a reverse proxy -app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) - # Enable CORS CORS( app, @@ -41,13 +37,8 @@ ) # Enable OIDC authentication -auth: AuthenticationMiddleware = AuthenticationMiddleware( - app=app, - client_id=os.getenv('CLIENT_ID'), - client_secret=os.getenv('CLIENT_SECRET'), - home_endpoint="catch_all" -) -app.before_request(lambda: auth(request=request, session=session)) +auth = AuthenticationMiddleware(app=app) +app.before_request(lambda: auth.authenticate(request=request, flask_session=session)) api.init_app(app) diff --git a/middlewares/auth.py b/middlewares/auth.py deleted file mode 100644 index 033ca78..0000000 --- a/middlewares/auth.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -This module implements an authentication middleware to -register a client to handle OAuth 2.0 authentication requests. -""" -import os -import re -import jwt -from jwt.exceptions import InvalidTokenError, ExpiredSignatureError -from authlib.integrations.flask_client import OAuth -from werkzeug.exceptions import HTTPException -from flask.sessions import SessionMixin -from flask import ( - Flask, - Blueprint, - Request, - Response, - session, - redirect, - url_for, - jsonify, -) - - -class AuthenticationMiddleware: - """ - This AuthenticationMiddleware sets OAuth 2.0 authentication for a Flask application - It handles the authentication by JWT access token and it is able to refresh - expired access tokens if a refresh token is available. - :param app: Flask application - :param client_id: OAuth 2.0 Application Client ID - :param client_secret: OAuth 2.0 Application Client Secret - :param valid_audiences: Authorized audiences (applications) whose tokens are - accepted by the web server - """ - - OIDC_CONFIG_DEFAULT: str = ( - "https://auth.cern.ch/auth/realms/cern/.well-known/openid-configuration" - ) - JWT_PUBLIC_KEY_URL: str = ( - "https://auth.cern.ch/auth/realms/cern/protocol/openid-connect/certs" - ) - JWT_REGEX_PATTERN: str = ( - r"eyJ([a-zA-Z0-9_=]+)\.([a-zA-Z0-9_=]+)\.([a-zA-Z0-9_\-\+\/=]*)" - ) - - def __init__( - self, - app: Flask, - client_id: str, - client_secret: str, - home_endpoint: str, - valid_audiences: list[str] = [], - ): - self.oidc_config: str = os.getenv( - "REALM_OIDC_CONFIG", AuthenticationMiddleware.OIDC_CONFIG_DEFAULT - ) - self.jwt_public_key_url: str = os.getenv( - "REALM_PUBLIC_KEY_URL", AuthenticationMiddleware.JWT_PUBLIC_KEY_URL - ) - self.jwt_regex_pattern: str = AuthenticationMiddleware.JWT_REGEX_PATTERN - self.jwt_regex = re.compile(self.jwt_regex_pattern) - self.app: Flask = self.__configure_session_cookie_security(app=app) - self.home_endpoint: str = home_endpoint - self.client_id: str = client_id - self.client_secret: str = client_secret - self.valid_audiences: list[str] = ( - [self.client_id] if not valid_audiences else valid_audiences - ) - self.jwk: jwt.PyJWK = self.__retrieve_jwk() - self.oauth_client: OAuth = self.__register_oauth_client() - self.oauth_blueprint: Blueprint = self.__register_blueprint() - - def __auth(self) -> Response: - """ - This endpoint starts the communication with the OAuth 2.0 Authorization Server - to request an access and refresh token. - """ - redirect_uri: str = url_for("oauth.callback", _external=True) - return self.oauth_client.cern.authorize_redirect(redirect_uri) - - def __callback(self) -> Response: - """ - This endpoint handles the callback from the OAuth 2.0 Authorization Server and - stores the access and refresh tokens inside a cookie handled by the Flask. - Also, this endpoint redirects the user back to its original destination. - """ - try: - token = self.oauth_client.cern.authorize_access_token() - session["token"] = { - "access_token": token["access_token"], - "refresh_token": token["refresh_token"], - } - original_destination: str = session.pop( - "next", default=url_for(self.home_endpoint) - ) - return redirect(original_destination) - except Exception: - return redirect(url_for(self.home_endpoint)) - - def __configure_session_cookie_security(self, app: Flask) -> Flask: - """ - Restrict the access to the session cookie. - The session cookie is going to be used to store the JWT token to authenticate the user, - the user data decrypted and the next endpoint the user is going to be redirected after a successful authentication. - Based on Flask documentation, the session cookie is cryptographically signed when it is transmitted to the - client web browser. For more information, please see: https://flask.palletsprojects.com/en/2.2.x/quickstart/?highlight=session#sessions - :return: Flask application with session cookie security set - :rtype: Flask - """ - # Configure the session cookie - app.config["SESSION_COOKIE_SAMESITE"] = "None" - app.config["SESSION_COOKIE_HTTPONLY"] = True - app.config["SESSION_COOKIE_SECURE"] = True - return app - - def __register_blueprint(self) -> Blueprint: - """ - Register a submodule (blueprint) inside the Flask application to - handle OAuth authentication. The new submodule is registered under the - /oauth2 url prefix. - :return Flask submodule (blueprint) - :rtype Blueprint - """ - oauth_blueprint = Blueprint("oauth", __name__) - # Register views - oauth_blueprint.add_url_rule( - rule="/auth", endpoint="auth", view_func=self.__auth - ) - oauth_blueprint.add_url_rule( - rule="/callback", endpoint="callback", view_func=self.__callback - ) - # Register OAuth submodule into the application - self.app.register_blueprint(blueprint=oauth_blueprint, url_prefix="/oauth2") - return oauth_blueprint - - def __register_oauth_client(self) -> OAuth: - """ - Register the OAuth 2.0 Client into the Flask application used to build token claim - requests. - :return OAuth 2.0 Client - :rtype OAuth - """ - # Set the client id and secret - client_credentials: dict = { - "CERN_CLIENT_ID": self.client_id, - "CERN_CLIENT_SECRET": self.client_secret, - } - - # Update the application to include this environment variables - self.app.config.from_mapping(client_credentials) - - # Register CERN Realm - oauth_client: OAuth = OAuth(app=self.app) - oauth_client.register( - name="cern", - server_metadata_url=self.oidc_config, - client_kwargs={ - "scope": "openid profile email", - }, - ) - return oauth_client - - def __retrieve_jwk(self) -> jwt.PyJWK: - """ - Retrieve the public key from the OAuth 2.0 Authorization Server used to - validate JWT access token. - :return JWK to validate JWT access token - :rtype PyJWK - """ - jwks_client = jwt.PyJWKClient(self.jwt_public_key_url) - return jwks_client.get_signing_keys()[0] - - def __token_to_user(self, decoded_token: dict) -> dict: - """ - Parse the user data included inside the JWT access token - and return the user information. - :return CERN user information - :rtype dict - """ - username: str = decoded_token.get("sub", "") - roles: list[str] = decoded_token.get("cern_roles", []) - email: str = decoded_token.get("email", "") - # Lowercase the email - email = email.lower() - - given_name: str = decoded_token.get("given_name", "") - family_name: str = decoded_token.get("family_name", "") - fullname: str = decoded_token.get("name", "") - return { - "username": username, - "roles": roles, - "email": email, - "given_name": given_name, - "family_name": family_name, - "fullname": fullname, - } - - def __decode_token(self, access_token: str) -> dict | None: - """ - Decodes a JWT access token and validates it using a JWK and the - valid audiences. - :raises ExpiredSignatureError: If the access token is expired - :raises HTTPException: If the access token was signed by an invalid provider, - if the token audience is not valid, - or if the claim dates are not valid - For more details, please see: - https://pyjwt.readthedocs.io/en/latest/api.html#exceptions - :return CERN user data included inside the JWT access token - :rtype dict - """ - jwt_raw_token = self.jwt_regex.search(access_token) - if jwt_raw_token: - raw_token = jwt_raw_token[0] - try: - decoded_token: dict = jwt.decode( - jwt=raw_token, - key=self.jwk.key, - audience=self.valid_audiences, - algorithms=["RS256"], - ) - return self.__token_to_user(decoded_token) - except ExpiredSignatureError as expired_error: - raise expired_error - except InvalidTokenError as token_error: - msg: str = ( - "The provided JWT token is invalid - " f"Details: {token_error}" - ) - error: dict = {"error": msg} - response: Response = jsonify(error) - response.status_code = 401 - raise HTTPException(description=msg, response=response) - - return None - - def __retrieve_token_from_session(self, session: SessionMixin) -> dict | None: - """ - Retrieves the access and refresh tokens from a cookie via Flask session. - Also, it attemps to refresh the access token if a refresh token is available - in case the access token has expired. - :return CERN user data included inside the JWT access token - :return None if there is no access token available inside the session cookie - or if there was an error while renewing the access token. - This None value indicates that an interactive authentication is required - :rtype dict | None - """ - session_cookie: dict = session.get("token") - if session_cookie: - access_token: str = session_cookie.get("access_token", "") - try: - user_info: dict | None = self.__decode_token(access_token=access_token) - return user_info - except ExpiredSignatureError: - # Try to refresh the token via refresh token claim - try: - refresh_token: str = session_cookie.get("refresh_token", "") - new_token: dict = self.oauth_client.cern.fetch_access_token( - refresh_token=refresh_token, grant_type="refresh_token" - ) - # Update the new token - new_access_token: str = new_token.get("access_token", "") - session["token"].update(new_token) - return self.__decode_token(access_token=new_access_token) - except Exception: - # Maybe the refresh token expired - # Force an interactive authentication - session.pop("token") - return None - return None - - def __retrieve_token_from_request(self, request: Request) -> dict | None: - """ - Retrieves the access token from the Authorization header - then it validates the token and retrieve the user data available inside . - :return CERN user data included inside the JWT access token - :return None if there is no access token available inside the Authorization header - or if the access token is expired. - This None value indicates that an interactive authentication is required - :rtype dict | None - """ - access_token = request.headers.get("Authorization") - if access_token: - try: - return self.__decode_token(access_token=access_token) - except ExpiredSignatureError: - # We are not able to retrieve a refresh token - # using the Authorization header - return None - return None - - def __call__(self, request: Request, session: SessionMixin) -> Response | None: - """ - Validate the access token and force a token request if necessary. - :return None if there is a valid access token to authenticate the user or if the user is - performing an authentication process. - - For more details, please see: - https://flask.palletsprojects.com/en/2.2.x/api/?highlight=environ#flask.Flask.before_request - - Otherwise, it will redirect the user to sign in for an interactive authentication. - """ - valid_auth_endpoints = ("oauth.auth", "oauth.callback") - if request.endpoint in valid_auth_endpoints: - # The user is performing an authentication process - # This is usefull when you require to install the middleware - # on the top of the Flask application. @before_request function - # is called before any view, therefore, this could lead to infinite - # redirect loops. - return None - - user_data: dict | None = self.__retrieve_token_from_request(request=request) - if user_data: - session["user"] = user_data - return None - # Check if authentication comes from a cookie session - user_data = self.__retrieve_token_from_session(session=session) - if user_data: - session["user"] = user_data - return None - - # Redirect to authentication endpoint: - # Store this information inside the session - original_destination: str = request.url - session["next"] = original_destination - - redirect_uri: str = url_for(endpoint="oauth.auth") - return redirect(location=redirect_uri) diff --git a/models/user.py b/models/user.py index 1b815f0..ec6ffc0 100644 --- a/models/user.py +++ b/models/user.py @@ -7,6 +7,7 @@ from core.validation import regex, required from core import Model +from core_lib.middlewares.auth import UserInfo as MiddlewareUserInfo EMAIL_FORMAT = r'(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)' @@ -49,8 +50,8 @@ def get_from_session(cls, session): ''' Get user from request ''' - email = session.get('user').get('email') - return cls.get_by_email(email) + user_data: MiddlewareUserInfo = session.get("user") + return cls.get_by_email(user_data.email) def requires(self, roles): ''' diff --git a/utils/logger.py b/utils/logger.py index c904dfd..a0b0030 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -8,6 +8,7 @@ import pathlib from logging.handlers import TimedRotatingFileHandler from flask import request, session, has_request_context +from core_lib.middlewares.auth import UserInfo as MiddlewareUserInfo LOGS_FOLDER: str = os.getenv("LOGS_FOLDER") or f"{os.getcwd()}/logs/" @@ -32,9 +33,9 @@ def return_request_data(self) -> dict: extra_data: dict = {"origin": "", "email": ""} if has_request_context(): request_origin = request.remote_addr - email = session.get("user", {}).get("email") + user_data: MiddlewareUserInfo = session.get("user") extra_data["origin"] = request_origin - extra_data["email"] = email + extra_data["email"] = user_data.email return extra_data def fill_remaining(self, record: logging.LogRecord) -> logging.LogRecord: diff --git a/utils/user.py b/utils/user.py index f05693d..a3029f7 100644 --- a/utils/user.py +++ b/utils/user.py @@ -3,6 +3,7 @@ ''' from werkzeug.exceptions import Forbidden from models.user import User +from core_lib.middlewares.auth import UserInfo as MiddlewareUserInfo def require_permission(session, roles, from_sso: bool = False): ''' @@ -11,8 +12,9 @@ def require_permission(session, roles, from_sso: bool = False): Raise Forbidden if do not have premission ''' if from_sso: + user_data: MiddlewareUserInfo = session.get("user") required_group = set(roles) - user_egroups = set(session.get('user').get('roles')) + user_egroups = set(user_data.roles) if required_group - user_egroups: raise Forbidden()