diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 208cbea2..5aa44fc0 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -38,8 +38,8 @@ jobs: run: | pip install poetry poetry install --with=code_quality,docs,test,types - - name: Lint with flake8 - run: poetry run flake8 + - name: Lint with ruff + run: poetry run ruff check . - name: Run mypy run: poetry run mypy foca - name: Start MongoDB diff --git a/Makefile b/Makefile index 3e2d953c..61dbe0c9 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ cv: clean-venv .PHONY: format-lint fl format-lint: @echo "\nRunning linter and formatter using ruff and typos +++++++++++++++++++++++++++++\n" - @poetry run flake8 + @poetry run ruff format && poetry run ruff check --fix @typos . fl: format-lint diff --git a/docs/api/conf.py b/docs/api/conf.py index 3c25d7c2..a0abc2d3 100644 --- a/docs/api/conf.py +++ b/docs/api/conf.py @@ -22,9 +22,9 @@ # -- Project information ----------------------------------------------------- -project = 'FOCA' -copyright = '2022, ELIXIR Cloud & AAI' -author = 'ELIXIR Cloud & AAI' +project = "FOCA" +copyright = "2022, ELIXIR Cloud & AAI" +author = "ELIXIR Cloud & AAI" # The full version, including alpha/beta/rc tags release = __version__ # noqa: F821 @@ -36,28 +36,28 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # Default doc to search for -master_doc = 'index' +master_doc = "index" # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -67,18 +67,13 @@ # -- Automation ------------------------------------------------------------- + # Auto-generate API doc def run_apidoc(_): - ignore_paths = [ - ] - argv = [ - "--force", - "--module-first", - "-o", "./modules", - "../../foca" - ] + ignore_paths + ignore_paths = [] + argv = ["--force", "--module-first", "-o", "./modules", "../../foca"] + ignore_paths apidoc.main(argv) def setup(app): - app.connect('builder-inited', run_apidoc) + app.connect("builder-inited", run_apidoc) diff --git a/examples/petstore-access-control/app.py b/examples/petstore-access-control/app.py index 9d250aad..96bb98a8 100644 --- a/examples/petstore-access-control/app.py +++ b/examples/petstore-access-control/app.py @@ -2,9 +2,7 @@ from foca import Foca -if __name__ == '__main__': - foca = Foca( - config_file="config.yaml" - ) +if __name__ == "__main__": + foca = Foca(config_file="config.yaml") app = foca.create_app() app.run() diff --git a/examples/petstore-access-control/controllers.py b/examples/petstore-access-control/controllers.py index cfc6a3f7..5cb4d199 100644 --- a/examples/petstore-access-control/controllers.py +++ b/examples/petstore-access-control/controllers.py @@ -2,13 +2,11 @@ import logging -from flask import (current_app, make_response) +from flask import current_app, make_response from pymongo.collection import Collection from exceptions import NotFound -from foca.security.access_control.register_access_control import ( - check_permissions -) +from foca.security.access_control.register_access_control import check_permissions logger = logging.getLogger(__name__) @@ -16,48 +14,48 @@ @check_permissions def findPets(limit=None, tags=None): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore-access-control'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore-access-control"] + .collections["pets"] + .client ) - filter_dict = {} if tags is None else {'tag': {'$in': tags}} + filter_dict = {} if tags is None else {"tag": {"$in": tags}} if not limit: limit = 0 - records = db_collection.find( - filter_dict, - {'_id': False} - ).sort([('$natural', -1)]).limit(limit) + records = ( + db_collection.find(filter_dict, {"_id": False}) + .sort([("$natural", -1)]) + .limit(limit) + ) return list(records) @check_permissions def addPet(pet): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore-access-control'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore-access-control"] + .collections["pets"] + .client ) counter = 0 - ctr = db_collection.find({}).sort([('$natural', -1)]) + ctr = db_collection.find({}).sort([("$natural", -1)]) if not db_collection.count_documents({}) == 0: - counter = ctr[0].get('id') + 1 - record = { - "id": counter, - "name": pet['name'], - "tag": pet['tag'] - } + counter = ctr[0].get("id") + 1 + record = {"id": counter, "name": pet["name"], "tag": pet["tag"]} db_collection.insert_one(record) - del record['_id'] + del record["_id"] return record @check_permissions def findPetById(id): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore-access-control'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore-access-control"] + .collections["pets"] + .client ) record = db_collection.find_one( {"id": id}, - {'_id': False}, + {"_id": False}, ) if record is None: raise NotFound @@ -67,17 +65,18 @@ def findPetById(id): @check_permissions def deletePet(id): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore-access-control'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore-access-control"] + .collections["pets"] + .client ) record = db_collection.find_one( {"id": id}, - {'_id': False}, + {"_id": False}, ) if record is None: raise NotFound db_collection.delete_one( {"id": id}, ) - response = make_response('', 204) + response = make_response("", 204) return response diff --git a/examples/petstore-access-control/exceptions.py b/examples/petstore-access-control/exceptions.py index 1658341a..4dadc724 100644 --- a/examples/petstore-access-control/exceptions.py +++ b/examples/petstore-access-control/exceptions.py @@ -45,9 +45,7 @@ "code": 401, }, Forbidden: { - "message": ( - "Sorry, but you don't have permission to play with the pets." - ), + "message": ("Sorry, but you don't have permission to play with the pets."), "code": 403, }, NotFound: { diff --git a/examples/petstore/app.py b/examples/petstore/app.py index 9d250aad..96bb98a8 100644 --- a/examples/petstore/app.py +++ b/examples/petstore/app.py @@ -2,9 +2,7 @@ from foca import Foca -if __name__ == '__main__': - foca = Foca( - config_file="config.yaml" - ) +if __name__ == "__main__": + foca = Foca(config_file="config.yaml") app = foca.create_app() app.run() diff --git a/examples/petstore/controllers.py b/examples/petstore/controllers.py index b8fad832..0f330d08 100644 --- a/examples/petstore/controllers.py +++ b/examples/petstore/controllers.py @@ -2,7 +2,7 @@ import logging -from flask import (current_app, make_response) +from flask import current_app, make_response from pymongo.collection import Collection from exceptions import NotFound @@ -12,46 +12,40 @@ def findPets(limit=None, tags=None): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore"].collections["pets"].client ) - filter_dict = {} if tags is None else {'tag': {'$in': tags}} + filter_dict = {} if tags is None else {"tag": {"$in": tags}} if not limit: limit = 0 - records = db_collection.find( - filter_dict, - {'_id': False} - ).sort([('$natural', -1)]).limit(limit) + records = ( + db_collection.find(filter_dict, {"_id": False}) + .sort([("$natural", -1)]) + .limit(limit) + ) return list(records) def addPet(pet): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore"].collections["pets"].client ) counter = 0 - ctr = db_collection.find({}).sort([('$natural', -1)]) + ctr = db_collection.find({}).sort([("$natural", -1)]) if not db_collection.count_documents({}) == 0: - counter = ctr[0].get('id') + 1 - record = { - "id": counter, - "name": pet['name'], - "tag": pet['tag'] - } + counter = ctr[0].get("id") + 1 + record = {"id": counter, "name": pet["name"], "tag": pet["tag"]} db_collection.insert_one(record) - del record['_id'] + del record["_id"] return record def findPetById(id): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore"].collections["pets"].client ) record = db_collection.find_one( {"id": id}, - {'_id': False}, + {"_id": False}, ) if record is None: raise NotFound @@ -60,17 +54,16 @@ def findPetById(id): def deletePet(id): db_collection: Collection = ( - current_app.config.foca.db.dbs['petstore'] - .collections['pets'].client + current_app.config.foca.db.dbs["petstore"].collections["pets"].client ) record = db_collection.find_one( {"id": id}, - {'_id': False}, + {"_id": False}, ) if record is None: raise NotFound db_collection.delete_one( {"id": id}, ) - response = make_response('', 204) + response = make_response("", 204) return response diff --git a/foca/api/register_openapi.py b/foca/api/register_openapi.py index c058effd..d0767af1 100644 --- a/foca/api/register_openapi.py +++ b/foca/api/register_openapi.py @@ -14,14 +14,23 @@ # Path Item object fields which contain an Operation object (ie: HTTP verbs). # Reference: https://swagger.io/specification/v3/#path-item-object -_OPERATION_OBJECT_FIELDS = frozenset({ - "get", "put", "post", "delete", "options", "head", "patch", "trace", -}) +_OPERATION_OBJECT_FIELDS = frozenset( + { + "get", + "put", + "post", + "delete", + "options", + "head", + "patch", + "trace", + } +) def register_openapi( - app: App, - specs: List[SpecConfig], + app: App, + specs: List[SpecConfig], ) -> App: """ Register OpenAPI specifications with Connexion application instance. @@ -41,7 +50,6 @@ def register_openapi( """ # Iterate over OpenAPI specs for spec in specs: - # Merge specs list_specs = [spec.path] if isinstance(spec.path, Path) else spec.path spec_parsed: Dict = ConfigParser.merge_yaml(*list_specs) @@ -56,26 +64,24 @@ def register_openapi( # Add/replace fields to Operation Objects if spec.add_operation_fields is not None: for key, val in spec.add_operation_fields.items(): - for path_item_object in spec_parsed.get('paths', {}).values(): + for path_item_object in spec_parsed.get("paths", {}).values(): for operation, operation_object in path_item_object.items(): # noqa: E501 if operation not in _OPERATION_OBJECT_FIELDS: continue operation_object[key] = val - logger.debug( - f"Added operation fields: {spec.add_operation_fields}" - ) + logger.debug(f"Added operation fields: {spec.add_operation_fields}") # Add fields to security definitions/schemes if not spec.disable_auth and spec.add_security_fields is not None: for key, val in spec.add_security_fields.items(): # OpenAPI 2 - sec_defs = spec_parsed.get('securityDefinitions', {}) + sec_defs = spec_parsed.get("securityDefinitions", {}) for sec_def in sec_defs.values(): sec_def[key] = val # OpenAPI 3 sec_schemes = spec_parsed.get( - 'components', {'securitySchemes': {}} - ).get('securitySchemes', {}) # type: ignore + "components", {"securitySchemes": {}} + ).get("securitySchemes", {}) # type: ignore for sec_scheme in sec_schemes.values(): sec_scheme[key] = val logger.debug(f"Added security fields: {spec.add_security_fields}") @@ -83,14 +89,14 @@ def register_openapi( # Remove security definitions/schemes and fields elif spec.disable_auth: # Open API 2 - spec_parsed.pop('securityDefinitions', None) + spec_parsed.pop("securityDefinitions", None) # Open API 3 - spec_parsed.get('components', {}).pop('securitySchemes', None) + spec_parsed.get("components", {}).pop("securitySchemes", None) # Open API 2/3 - spec_parsed.pop('security', None) - for path_item_object in spec_parsed.get('paths', {}).values(): + spec_parsed.pop("security", None) + for path_item_object in spec_parsed.get("paths", {}).values(): for operation_object in path_item_object.values(): - operation_object.pop('security', None) + operation_object.pop("security", None) logger.debug("Removed security fields") # Attach specs to connexion App @@ -98,7 +104,7 @@ def register_openapi( spec.connexion = {} if spec.connexion is None else spec.connexion app.add_api( specification=spec_parsed, - **spec.model_dump().get('connexion', {}), + **spec.model_dump().get("connexion", {}), ) logger.info(f"API endpoints added from spec: {spec.path_out}") diff --git a/foca/config/config_parser.py b/foca/config/config_parser.py index 619afbcf..4a2e6c06 100644 --- a/foca/config/config_parser.py +++ b/foca/config/config_parser.py @@ -4,18 +4,18 @@ import logging from logging.config import dictConfig from pathlib import Path -from typing import (Dict, Optional) +from typing import Dict, Optional from addict import Dict as Addict from pydantic import BaseModel import yaml -from foca.models.config import (Config, LogConfig) +from foca.models.config import Config, LogConfig logger = logging.getLogger(__name__) -class ConfigParser(): +class ConfigParser: """Parse FOCA config files. Args: @@ -51,7 +51,7 @@ def __init__( self, config_file: Optional[Path] = None, custom_config_model: Optional[str] = None, - format_logs: bool = True + format_logs: bool = True, ) -> None: """Constructor method.""" if config_file is not None: @@ -61,10 +61,10 @@ def __init__( if custom_config_model is not None: setattr( self.config, - 'custom', + "custom", self.parse_custom_config( model=custom_config_model, - ) + ), ) if format_logs: self._configure_logging() @@ -100,13 +100,9 @@ def parse_yaml(conf: Path) -> Dict: try: return yaml.safe_load(config_file) except yaml.YAMLError as exc: - raise ValueError( - f"file '{conf}' is not valid YAML" - ) from exc + raise ValueError(f"file '{conf}' is not valid YAML") from exc except OSError as exc: - raise OSError( - f"file '{conf}' could not be read" - ) from exc + raise OSError(f"file '{conf}' could not be read") from exc @staticmethod def merge_yaml(*args: Path) -> Dict: @@ -170,7 +166,8 @@ def parse_custom_config(self, model: str) -> BaseModel: ) try: custom_config = model_class( # type: ignore[operator] - **self.config.custom) # type: ignore[attr-defined] + **self.config.custom + ) # type: ignore[attr-defined] except Exception as exc: raise ValueError( "failed validating custom configuration: provided custom " diff --git a/foca/database/register_mongodb.py b/foca/database/register_mongodb.py index 4433515d..b4566735 100644 --- a/foca/database/register_mongodb.py +++ b/foca/database/register_mongodb.py @@ -30,7 +30,6 @@ def register_mongodb( # Iterate over databases if conf.dbs is not None: for db_name, db_conf in conf.dbs.items(): - # Instantiate PyMongo client mongo = _create_mongo_client( app=app, @@ -45,36 +44,24 @@ def register_mongodb( # Add collections if db_conf.collections is not None and db_conf.client is not None: for coll_name, coll_conf in db_conf.collections.items(): - coll_conf.client = db_conf.client[coll_name] - logger.info( - f"Added database collection '{coll_name}'." - ) + logger.info(f"Added database collection '{coll_name}'.") # Add indexes - if ( - coll_conf.indexes is not None - and coll_conf.client is not None - ): + if coll_conf.indexes is not None and coll_conf.client is not None: # Remove already created indexes if any coll_conf.client.drop_indexes() for index in coll_conf.indexes: if index.keys is not None: coll_conf.client.create_index( - index.keys, **index.options) - logger.info( - f"Indexes created for collection '{coll_name}'." - ) + index.keys, **index.options + ) + logger.info(f"Indexes created for collection '{coll_name}'.") return conf -def add_new_database( - app: Flask, - conf: MongoConfig, - db_conf: DBConfig, - db_name: str -): +def add_new_database(app: Flask, conf: MongoConfig, db_conf: DBConfig, db_name: str): """Register an additional db to database config. Args: @@ -99,18 +86,15 @@ def add_new_database( # Add collections if db_conf.collections is not None and db_conf.client is not None: for coll_name, coll_conf in db_conf.collections.items(): - coll_conf.client = db_conf.client[coll_name] - logger.info( - f"Added database collection '{coll_name}'." - ) + logger.info(f"Added database collection '{coll_name}'.") def _create_mongo_client( - app: Flask, - host: str = 'mongodb', - port: int = 27017, - db: str = 'database', + app: Flask, + host: str = "mongodb", + port: int = 27017, + db: str = "database", ) -> PyMongo: """Create MongoDB client for Flask application instance. @@ -123,30 +107,30 @@ def _create_mongo_client( Returns: MongoDB client for Flask application instance. """ - auth = '' - user = os.environ.get('MONGO_USERNAME') + auth = "" + user = os.environ.get("MONGO_USERNAME") if user is not None and user != "": - auth = '{username}:{password}@'.format( - username=os.environ.get('MONGO_USERNAME'), - password=os.environ.get('MONGO_PASSWORD'), + auth = "{username}:{password}@".format( + username=os.environ.get("MONGO_USERNAME"), + password=os.environ.get("MONGO_PASSWORD"), ) - app.config['MONGO_URI'] = 'mongodb://{auth}{host}:{port}/{db}'.format( - host=os.environ.get('MONGO_HOST', host), - port=os.environ.get('MONGO_PORT', port), - db=os.environ.get('MONGO_DBNAME', db), - auth=auth + app.config["MONGO_URI"] = "mongodb://{auth}{host}:{port}/{db}".format( + host=os.environ.get("MONGO_HOST", host), + port=os.environ.get("MONGO_PORT", port), + db=os.environ.get("MONGO_DBNAME", db), + auth=auth, ) mongo = PyMongo(app) logger.info( ( "Registered database '{db}' at URI '{host}':'{port}' with Flask " - 'application.' + "application." ).format( - db=os.environ.get('MONGO_DBNAME', db), - host=os.environ.get('MONGO_HOST', host), - port=os.environ.get('MONGO_PORT', port) + db=os.environ.get("MONGO_DBNAME", db), + host=os.environ.get("MONGO_HOST", host), + port=os.environ.get("MONGO_PORT", port), ) ) return mongo diff --git a/foca/errors/exceptions.py b/foca/errors/exceptions.py index 0a62223e..8a8d0f5d 100644 --- a/foca/errors/exceptions.py +++ b/foca/errors/exceptions.py @@ -3,7 +3,7 @@ from copy import deepcopy import logging from traceback import format_exception -from typing import (Dict, List) +from typing import Dict, List from connexion import App from connexion.exceptions import ( @@ -12,7 +12,7 @@ OAuthProblem, Unauthorized, ) -from flask import (current_app, Response) +from flask import current_app, Response from json import dumps from werkzeug.exceptions import ( BadRequest, @@ -73,7 +73,7 @@ GatewayTimeout: { "title": "Gateway Timeout", "status": 504, - } + }, } @@ -107,12 +107,8 @@ def _exc_to_str( Returns: String representation of exception. """ - exc_lines = format_exception( - exc.__class__, - exc, - exc.__traceback__ - ) - exc_stripped = [e.rstrip('\n') for e in exc_lines] + exc_lines = format_exception(exc.__class__, exc, exc.__traceback__) + exc_stripped = [e.rstrip("\n") for e in exc_lines] exc_split = [] for item in exc_stripped: exc_split.extend(item.splitlines()) @@ -121,7 +117,7 @@ def _exc_to_str( def _log_exception( exc: BaseException, - format: str = 'oneline', + format: str = "oneline", ) -> None: """Log exception with indicated format. @@ -134,11 +130,11 @@ def _log_exception( or ``regular`` (exception logged with entire trace stack, typically across multiple lines). """ - exc_str = '' + exc_str = "" valid_formats = [ - 'oneline', - 'minimal', - 'regular', + "oneline", + "minimal", + "regular", ] if format in valid_formats: if format == "oneline": @@ -146,10 +142,7 @@ def _log_exception( elif format == "minimal": exc_str = f"{type(exc).__name__}: {str(exc)}" else: - exc_str = _exc_to_str( - exc=exc, - delimiter='\n' - ) + exc_str = _exc_to_str(exc=exc, delimiter="\n") logger.error(exc_str) else: logger.error("Error logging is misconfigured.") @@ -216,16 +209,15 @@ def _problem_handler_json(exception: Exception) -> Response: if exc not in conf.mapping: exc = Exception try: - status = int(_get_by_path( - obj=conf.mapping[exc], - key_sequence=conf.status_member, - )) + status = int( + _get_by_path( + obj=conf.mapping[exc], + key_sequence=conf.status_member, + ) + ) except KeyError: if conf.logging.value != "none": - _log_exception( - exc=exception, - format=conf.logging.value - ) + _log_exception(exc=exception, format=conf.logging.value) return Response( status=500, mimetype="application/problem+json", @@ -233,25 +225,26 @@ def _problem_handler_json(exception: Exception) -> Response: # Log exception JSON & traceback if conf.logging.value != "none": logger.error(conf.mapping[exc]) - _log_exception( - exc=exception, - format=conf.logging.value - ) + _log_exception(exc=exception, format=conf.logging.value) # Filter members to be returned to user keep = deepcopy(conf.mapping[exc]) if conf.public_members is not None: keep = {} for member in deepcopy(conf.public_members): - keep.update(_subset_nested_dict( - obj=conf.mapping[exc], - key_sequence=member, - )) + keep.update( + _subset_nested_dict( + obj=conf.mapping[exc], + key_sequence=member, + ) + ) elif conf.private_members is not None: for member in deepcopy(conf.private_members): - keep.update(_exclude_key_nested_dict( - obj=keep, - key_sequence=member, - )) + keep.update( + _exclude_key_nested_dict( + obj=keep, + key_sequence=member, + ) + ) # Return response return Response( response=dumps(keep), diff --git a/foca/factories/celery_app.py b/foca/factories/celery_app.py index 271783ef..5e136562 100644 --- a/foca/factories/celery_app.py +++ b/foca/factories/celery_app.py @@ -28,18 +28,19 @@ def create_celery_app(app: Flask) -> Celery: backend=conf.backend, include=conf.include, ) - calling_module = ':'.join([stack()[1].filename, stack()[1].function]) + calling_module = ":".join([stack()[1].filename, stack()[1].function]) logger.debug(f"Celery app created from '{calling_module}'.") # Update Celery app configuration with Flask app configuration - setattr(celery.conf, 'foca', app.config.foca) # type: ignore[attr-defined] - logger.debug('Celery app configured.') + setattr(celery.conf, "foca", app.config.foca) # type: ignore[attr-defined] + logger.debug("Celery app configured.") class ContextTask(celery.Task): # type: ignore # https://github.com/python/mypy/issues/4284) """Create subclass of task that wraps task execution in application context. """ + def __call__(self, *args, **kwargs): """Wrap task execution in application context.""" with app.app_context(): # pragma: no cover diff --git a/foca/factories/connexion_app.py b/foca/factories/connexion_app.py index b39b7f6c..62620292 100644 --- a/foca/factories/connexion_app.py +++ b/foca/factories/connexion_app.py @@ -27,7 +27,7 @@ def create_connexion_app(config: Optional[Config] = None) -> App: skip_error_handlers=True, ) - calling_module = ':'.join([stack()[1].filename, stack()[1].function]) + calling_module = ":".join([stack()[1].filename, stack()[1].function]) logger.debug(f"Connexion app created from '{calling_module}'.") # Configure Connexion app @@ -62,16 +62,16 @@ def __add_config_to_connexion_app( app.debug = conf.debug # replace Flask app settings - app.app.config['DEBUG'] = conf.debug - app.app.config['ENV'] = conf.environment - app.app.config['TESTING'] = conf.testing + app.app.config["DEBUG"] = conf.debug + app.app.config["ENV"] = conf.environment + app.app.config["TESTING"] = conf.testing - logger.debug('Flask app settings:') - for (key, value) in app.app.config.items(): - logger.debug('* {}: {}'.format(key, value)) + logger.debug("Flask app settings:") + for key, value in app.app.config.items(): + logger.debug("* {}: {}".format(key, value)) # Add user configuration to Flask app config - setattr(app.app.config, 'foca', config) + setattr(app.app.config, "foca", config) - logger.debug('Connexion app configured.') + logger.debug("Connexion app configured.") return app diff --git a/foca/models/config.py b/foca/models/config.py index 121e2619..cc75b3e1 100644 --- a/foca/models/config.py +++ b/foca/models/config.py @@ -29,7 +29,7 @@ from foca.security.access_control.constants import ( ACCESS_CONTROL_BASE_PATH, - DEFAULT_MODEL_FILE + DEFAULT_MODEL_FILE, ) @@ -52,10 +52,7 @@ def _validate_log_level_choices(cls, level: int) -> int: return level -def _get_by_path( - obj: Dict, - key_sequence: List[str] -) -> Any: +def _get_by_path(obj: Dict, key_sequence: List[str]) -> Any: """Access a nested dictionary by sequence of keys. Args: @@ -79,6 +76,7 @@ class ExceptionLoggingEnum(Enum): regular: The exception is logged with the entire traceback stack, typically on multiple lines. """ + minimal = "minimal" none = "none" regular = "regular" @@ -93,6 +91,7 @@ class ValidationMethodsEnum(Enum): userinfo: JWT validation via OpenID Connect-compliant identity provider's ``/userinfo`` endpoint. """ + public_key = "public_key" userinfo = "userinfo" @@ -107,6 +106,7 @@ class ValidationChecksEnum(Enum): any: Any method is sufficient to validate the JWT; validation succeeds after the first successful check. """ + all = "all" any = "any" @@ -126,6 +126,7 @@ class PymongoDirectionEnum(Enum): HASHED: Index specifier for a hashed index. TEXT: Index specifier for a text index. """ + ASCENDING = 1 DESCENDING = -1 GEO2D = "2d" @@ -137,7 +138,8 @@ class PymongoDirectionEnum(Enum): class FOCABaseConfig(BaseModel): """Base configuration for FOCA models.""" - model_config = ConfigDict(extra='forbid', arbitrary_types_allowed=True) + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) class ServerConfig(FOCABaseConfig): @@ -192,6 +194,7 @@ class ServerConfig(FOCABaseConfig): ServerConfig(host='0.0.0.0', port=8080, debug=True, environment='devel\ opment', testing=False, use_reloader=True) """ + host: str = "0.0.0.0" port: int = 8080 debug: bool = True @@ -309,6 +312,7 @@ class ExceptionConfig(FOCABaseConfig): ': 'Service Unavailable', 'status': 502}, : {'title': 'Gateway Timeout', 'status': 504}}) """ + required_members: List[List[str]] = [["title"], ["status"]] extension_members: Union[bool, List[List[str]]] = False status_member: List[str] = ["status"] @@ -366,14 +370,11 @@ def validate_exceptions_mapping(self) -> Self: # Ensure that status member is among required members if self.status_member not in self.required_members: - raise ValueError( - "Status member is not among required members." - ) + raise ValueError("Status member is not among required members.") # Filter members that are returned to the user - if ( - isinstance(self.public_members, list) and - isinstance(self.private_members, list) + if isinstance(self.public_members, list) and isinstance( + self.private_members, list ): raise ValueError( "Both public and private member filters are active, but at " @@ -400,10 +401,9 @@ def validate_exceptions_mapping(self) -> Self: # Ensure that each exception fulfills all requirements for key, val in exc_dict.items(): - # Keys are exceptions try: - getattr(key, '__cause__') + getattr(key, "__cause__") except AttributeError as exc: raise ValueError( f"Key '{key}' in 'exceptions' dictionary does not appear " @@ -597,6 +597,7 @@ class SpecConfig(FOCABaseConfig): token', 'x-some-other-custom-field': 'some_value'}, disable_auth=False, connex\ ion=None) """ + path: Union[Path, List[Path]] path_out: Optional[Path] = None append: Optional[List[Dict]] = None @@ -612,14 +613,8 @@ def make_paths_absolute_and_set_path_out(self) -> Self: Returns: Model instance with absolute paths and output path set. """ - paths = ( - self.path if isinstance(self.path, list) - else [self.path] - ) - self.path = [ - path if path.is_absolute() - else path.resolve() for path in paths - ] + paths = self.path if isinstance(self.path, list) else [self.path] + self.path = [path if path.is_absolute() else path.resolve() for path in paths] if self.path_out is None: _path = self.path[0].resolve() self.path_out = _path.parent / f"{_path.stem}.modified.yaml" @@ -653,6 +648,7 @@ class APIConfig(FOCABaseConfig): ath_out=PosixPath('/path/to/specs.modified.yaml'), append=None, add_operation_\ fields=None, add_security_fields=None, disable_auth=False, connexion=None)]) """ + specs: List[SpecConfig] = [] @@ -703,6 +699,7 @@ class AccessControlConfig(FOCABaseConfig): db', collection_name='access_control_collection', model='/path/to/policy.co\ nf', owner_headers={'X-User', 'X-Group'}, user_headers={'X-User'}) """ + api_specs: Optional[str] = None api_controllers: Optional[str] = None db_name: Optional[str] = None @@ -711,7 +708,7 @@ class AccessControlConfig(FOCABaseConfig): owner_headers: Optional[set] = None user_headers: Optional[set] = None - @field_validator('model', mode='before') + @field_validator("model", mode="before") @classmethod def validate_model_path(cls, v: Optional[str]) -> str: """Validate the model path. @@ -726,10 +723,7 @@ def validate_model_path(cls, v: Optional[str]) -> str: """ if v is None: - with resource_path( - ACCESS_CONTROL_BASE_PATH, - DEFAULT_MODEL_FILE - ) as _path: + with resource_path(ACCESS_CONTROL_BASE_PATH, DEFAULT_MODEL_FILE) as _path: return str(_path) model_path = Path(v) @@ -802,6 +796,7 @@ class AuthConfig(FOCABaseConfig): onMethodsEnum.public_key: 'public_key'>], validation_checks=) """ + required: bool = True add_key_to_claims: bool = True allow_expired: bool = False @@ -831,6 +826,7 @@ class CORSConfig(FOCABaseConfig): ... ) CORSConfig(enabled=True) """ + enabled: bool = True @@ -867,6 +863,7 @@ class SecurityConfig(FOCABaseConfig): ', collection_name='access_control_collection', model='/path/to/policy.conf', \ owner_headers={'X-User', 'X-Group'}, user_headers={'X-User'})) """ + access_control: AccessControlConfig = AccessControlConfig() auth: AuthConfig = AuthConfig() cors: CORSConfig = CORSConfig() @@ -899,20 +896,18 @@ class IndexConfig(FOCABaseConfig): IndexConfig(keys=[('name', -1), ('id', 1)], options={'unique': True, '\ sparse': False}) """ + keys: Optional[Union[Dict, List[Tuple]]] = None options: Dict = dict() @field_validator("keys", mode="after") @classmethod def store_enum_value( - cls, - v: Optional[Union[Dict, List[Tuple]]] + cls, v: Optional[Union[Dict, List[Tuple]]] ) -> Optional[Union[Dict, List[Tuple]]]: """Convert dict values of keys into list of tuples""" if v is not None and isinstance(v, dict): - v = [ - tuple([key, val]) for key, val in v.items() - ] + v = [tuple([key, val]) for key, val in v.items()] return v @@ -940,6 +935,7 @@ class CollectionConfig(FOCABaseConfig): CollectionConfig(indexes=[IndexConfig(keys=[('last_name', 1)], options\ ={})], client=None)}, client=None) """ + indexes: Optional[List[IndexConfig]] = None client: Optional[collection.Collection] = None @@ -974,6 +970,7 @@ class DBConfig(FOCABaseConfig): DBConfig(collections={'my_collection': CollectionConfig(indexes=[Index\ Config(keys=[('last_name', 1)], options={})], client=None)}, client=None) """ + collections: Optional[Dict[str, CollectionConfig]] = None client: Optional[database.Database] = None @@ -1005,6 +1002,7 @@ class MongoConfig(FOCABaseConfig): ... ) MongoConfig(host='mongodb', port=27017, dbs=None) """ + host: str = "mongodb" port: int = 27017 dbs: Optional[Dict[str, DBConfig]] = None @@ -1034,14 +1032,15 @@ class JobsConfig(FOCABaseConfig): >>> JobsConfig( ... host="rabbitmq", ... port=5672, - ... backend='rpc://', + ... backend="rpc://", ... include=[], ... ) JobsConfig(host='rabbitmq', port=5672, backend='rpc://', include=[]) """ + host: str = "rabbitmq" port: int = 5672 - backend: str = 'rpc://' + backend: str = "rpc://" include: Optional[List[str]] = None @@ -1070,6 +1069,7 @@ class LogFormatterConfig(FOCABaseConfig): LogFormatterConfig(class_formatter='logging.Formatter', style='{', for\ mat='[{asctime}: {levelname:<8}] {message} [{name}]') """ + class_formatter: str = Field( "logging.Formatter", alias="class", @@ -1106,6 +1106,7 @@ class LogHandlerConfig(FOCABaseConfig): LogHandlerConfig(class_handler='logging.StreamHandler', level=20, form\ atter='standard', stream='ext://sys.stderr') """ + class_handler: str = Field( "logging.StreamHandler", alias="class", @@ -1114,7 +1115,7 @@ class LogHandlerConfig(FOCABaseConfig): formatter: str = "standard" stream: str = "ext://sys.stderr" - _validate_level = field_validator('level')(_validate_log_level_choices) + _validate_level = field_validator("level")(_validate_log_level_choices) class LogRootConfig(FOCABaseConfig): @@ -1139,10 +1140,11 @@ class LogRootConfig(FOCABaseConfig): ... ) LogRootConfig(level=20, handlers=['console']) """ + level: int = 10 handlers: Optional[List[str]] = ["console"] - _validate_level = field_validator('level')(_validate_log_level_choices) + _validate_level = field_validator("level")(_validate_log_level_choices) class LogConfig(FOCABaseConfig): @@ -1195,6 +1197,7 @@ class LogConfig(FOCABaseConfig): ndard', stream='ext://sys.stderr')}, root=LogRootConfig(level=10, handlers=['c\ onsole'])) """ + version: int = 1 disable_existing_loggers: bool = False formatters: Optional[Dict[str, LogFormatterConfig]] = { @@ -1267,6 +1270,7 @@ class 'werkzeug.exceptions.BadGateway'>: {'title': 'Bad Gateway', 'status': 50\ tream='ext://sys.stderr')}, root=LogRootConfig(level=10, handlers=['console'])\ )) """ + server: ServerConfig = ServerConfig() exceptions: ExceptionConfig = ExceptionConfig() api: APIConfig = APIConfig() @@ -1274,4 +1278,4 @@ class 'werkzeug.exceptions.BadGateway'>: {'title': 'Bad Gateway', 'status': 50\ db: Optional[MongoConfig] = None jobs: Optional[JobsConfig] = None log: LogConfig = LogConfig() - model_config = ConfigDict(extra='allow') + model_config = ConfigDict(extra="allow") diff --git a/foca/security/access_control/access_control_server.py b/foca/security/access_control/access_control_server.py index e55474a6..0e0fdcbc 100644 --- a/foca/security/access_control/access_control_server.py +++ b/foca/security/access_control/access_control_server.py @@ -1,12 +1,12 @@ -""""Controllers for permission management endpoints.""" +""" "Controllers for permission management endpoints.""" import logging -from typing import (Dict, List) +from typing import Dict, List -from flask import (request, current_app) +from flask import request, current_app from pymongo.collection import Collection -from werkzeug.exceptions import (InternalServerError, NotFound) +from werkzeug.exceptions import InternalServerError, NotFound from foca.utils.logging import log_traffic from foca.errors.exceptions import BadRequest @@ -32,11 +32,10 @@ def postPermission() -> str: rule.get("v2", None), rule.get("v3", None), rule.get("v4", None), - rule.get("v5", None) + rule.get("v5", None), ] permission_id = access_control_adapter.save_policy_line( - ptype=request_json.get("policy_type", None), - rule=permission_data + ptype=request_json.get("policy_type", None), rule=permission_data ) logger.info("New policy added.") return permission_id @@ -64,23 +63,21 @@ def putPermission( if isinstance(request_json, dict): app_config = current_app.config try: - security_conf = \ - app_config.foca.security # type: ignore[attr-defined] - access_control_config = \ - security_conf.access_control # type: ignore[attr-defined] + security_conf = app_config.foca.security # type: ignore[attr-defined] + access_control_config = security_conf.access_control # type: ignore[attr-defined] db_coll_permission: Collection = ( app_config.foca.db.dbs[ # type: ignore[attr-defined] - access_control_config.db_name] - .collections[access_control_config.collection_name].client + access_control_config.db_name + ] + .collections[access_control_config.collection_name] + .client ) permission_data = request_json.get("rule", {}) permission_data["id"] = id permission_data["ptype"] = request_json.get("policy_type", None) db_coll_permission.replace_one( - filter={"id": id}, - replacement=permission_data, - upsert=True + filter={"id": id}, replacement=permission_data, upsert=True ) logger.info("Policy updated.") return id @@ -103,21 +100,21 @@ def getAllPermissions(limit=None) -> List[Dict]: List of permission dicts. """ app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + access_control_config = app_config.foca.security.access_control # type: ignore[attr-defined] db_coll_permission: Collection = ( app_config.foca.db.dbs[ # type: ignore[attr-defined] access_control_config.db_name - ].collections[access_control_config.collection_name].client + ] + .collections[access_control_config.collection_name] + .client ) if not limit: limit = 0 permissions = list( - db_coll_permission.find( - filter={}, - projection={'_id': False} - ).sort([('$natural', -1)]).limit(limit) + db_coll_permission.find(filter={}, projection={"_id": False}) + .sort([("$natural", -1)]) + .limit(limit) ) user_permission_list = [] for _permission in permissions: @@ -126,11 +123,9 @@ def getAllPermissions(limit=None) -> List[Dict]: del _permission["ptype"] del _permission["id"] rule = _permission - user_permission_list.append({ - "policy_type": policy_type, - "rule": rule, - "id": id - }) + user_permission_list.append( + {"policy_type": policy_type, "rule": rule, "id": id} + ) return user_permission_list @@ -147,12 +142,13 @@ def getPermission( Permission data for the given id. """ app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + access_control_config = app_config.foca.security.access_control # type: ignore[attr-defined] db_coll_permission: Collection = ( app_config.foca.db.dbs[ # type: ignore[attr-defined] access_control_config.db_name - ].collections[access_control_config.collection_name].client + ] + .collections[access_control_config.collection_name] + .client ) permission = db_coll_permission.find_one(filter={"id": id}) @@ -163,11 +159,7 @@ def getPermission( id = permission.get("id", None) del permission["ptype"] del permission["id"] - return { - "id": id, - "rule": permission, - "policy_type": policy_type - } + return {"id": id, "rule": permission, "policy_type": policy_type} @log_traffic @@ -183,15 +175,16 @@ def deletePermission( Delete permission identifier. """ app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + access_control_config = app_config.foca.security.access_control # type: ignore[attr-defined] db_coll_permission: Collection = ( app_config.foca.db.dbs[ # type: ignore[attr-defined] access_control_config.db_name - ].collections[access_control_config.collection_name].client + ] + .collections[access_control_config.collection_name] + .client ) - del_obj_permission = db_coll_permission.delete_one({'id': id}) + del_obj_permission = db_coll_permission.delete_one({"id": id}) if del_obj_permission.deleted_count: return id diff --git a/foca/security/access_control/constants.py b/foca/security/access_control/constants.py index eda4d1ed..aa6aed0d 100644 --- a/foca/security/access_control/constants.py +++ b/foca/security/access_control/constants.py @@ -1,5 +1,4 @@ -"""File to store permission based constants. -""" +"""File to store permission based constants.""" DEFAULT_ACCESS_CONTROL_DB_NAME = "access_control_db" DEFAULT_ACESS_CONTROL_COLLECTION_NAME = "policy_rules" diff --git a/foca/security/access_control/foca_casbin_adapter/adapter.py b/foca/security/access_control/foca_casbin_adapter/adapter.py index a6f0a0c7..8be49ce9 100644 --- a/foca/security/access_control/foca_casbin_adapter/adapter.py +++ b/foca/security/access_control/foca_casbin_adapter/adapter.py @@ -2,12 +2,10 @@ from casbin import persist from casbin.model import Model -from typing import (List, Optional) +from typing import List, Optional from pymongo import MongoClient -from foca.security.access_control.foca_casbin_adapter.casbin_rule import ( - CasbinRule -) +from foca.security.access_control.foca_casbin_adapter.casbin_rule import CasbinRule from foca.utils.misc import generate_id @@ -94,10 +92,7 @@ def _delete_policy_lines(self, ptype: str, rule: List[str]) -> int: else: line_dict = line.dict() line_dict_keys_len = len(line_dict) - results = self._collection.find( - filter=line_dict, - projection={"id": False} - ) + results = self._collection.find(filter=line_dict, projection={"id": False}) to_delete = [ result["_id"] for result in results @@ -153,8 +148,7 @@ def remove_policy(self, sec: str, ptype: str, rule: List[str]): return deleted_count > 0 def remove_filtered_policy( - self, sec: str, ptype: str, - field_index: int, *field_values: List[str] + self, sec: str, ptype: str, field_index: int, *field_values: List[str] ): """Remove policy rules that match the filter from the storage. This is part of the Auto-Save feature. diff --git a/foca/security/access_control/foca_casbin_adapter/casbin_rule.py b/foca/security/access_control/foca_casbin_adapter/casbin_rule.py index 7c91c65a..2de904d7 100644 --- a/foca/security/access_control/foca_casbin_adapter/casbin_rule.py +++ b/foca/security/access_control/foca_casbin_adapter/casbin_rule.py @@ -1,6 +1,6 @@ """Casbin rule class.""" -from typing import (Dict, Optional) +from typing import Dict, Optional class CasbinRule: @@ -33,7 +33,7 @@ def __init__( v2: Optional[str] = None, v3: Optional[str] = None, v4: Optional[str] = None, - v5: Optional[str] = None + v5: Optional[str] = None, ): """Casbin rule object initializer.""" self.ptype = ptype diff --git a/foca/security/access_control/register_access_control.py b/foca/security/access_control/register_access_control.py index ffd1894a..59ca6243 100644 --- a/foca/security/access_control/register_access_control.py +++ b/foca/security/access_control/register_access_control.py @@ -4,7 +4,7 @@ from functools import wraps from importlib.resources import path as resource_path from pathlib import Path -from typing import (Callable, Optional, Tuple) +from typing import Callable, Optional, Tuple from connexion import App from connexion.exceptions import Forbidden @@ -17,13 +17,13 @@ MongoConfig, SpecConfig, CollectionConfig, - AccessControlConfig + AccessControlConfig, ) from foca.database.register_mongodb import add_new_database from foca.security.access_control.foca_casbin_adapter.adapter import Adapter from foca.security.access_control.constants import ( ACCESS_CONTROL_BASE_PATH, - DEFAULT_API_SPEC_PATH + DEFAULT_API_SPEC_PATH, ) logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def register_access_control( cnx_app: App, mongo_config: Optional[MongoConfig], - access_control_config: AccessControlConfig + access_control_config: AccessControlConfig, ) -> App: """Register access control configuration with flask app. @@ -50,10 +50,10 @@ def register_access_control( """ # Register access control database and collection. access_db_conf = DBConfig( - collections={ - access_control_config.collection_name: CollectionConfig() - } if access_control_config.collection_name is not None else {}, - client=None + collections={access_control_config.collection_name: CollectionConfig()} + if access_control_config.collection_name is not None + else {}, + client=None, ) # Set default db attributes if config not present. @@ -73,28 +73,24 @@ def register_access_control( app=cnx_app.app, conf=mongo_config, db_conf=access_db_conf, - db_name=access_control_db + db_name=access_control_db, ) # Register access control api specs and corresponding controller. cnx_app = register_casbin_enforcer( app=cnx_app, mongo_config=mongo_config, - access_control_config=access_control_config + access_control_config=access_control_config, ) cnx_app = register_permission_specs( - app=cnx_app, - access_control_config=access_control_config + app=cnx_app, access_control_config=access_control_config ) return cnx_app -def register_permission_specs( - app: App, - access_control_config: AccessControlConfig -): +def register_permission_specs(app: App, access_control_config: AccessControlConfig): """Register open api specs for permission management. Args: @@ -108,9 +104,7 @@ def register_permission_specs( """ # Check if default, get package path variables for specs. if access_control_config.api_specs is None: - with resource_path( - ACCESS_CONTROL_BASE_PATH, DEFAULT_API_SPEC_PATH - ) as _path: + with resource_path(ACCESS_CONTROL_BASE_PATH, DEFAULT_API_SPEC_PATH) as _path: spec_path = str(_path) else: spec_path = access_control_config.api_specs @@ -118,18 +112,13 @@ def register_permission_specs( spec = SpecConfig( path=Path(spec_path), add_operation_fields={ - "x-openapi-router-controller": ( - access_control_config.api_controllers - ) + "x-openapi-router-controller": (access_control_config.api_controllers) }, connexion={ "strict_validation": True, "validate_responses": True, - "options": { - "swagger_ui": True, - "serve_spec": True - } - } + "options": {"swagger_ui": True, "serve_spec": True}, + }, ) app.add_api( @@ -140,9 +129,7 @@ def register_permission_specs( def register_casbin_enforcer( - app: App, - access_control_config: AccessControlConfig, - mongo_config: MongoConfig + app: App, access_control_config: AccessControlConfig, mongo_config: MongoConfig ) -> App: """Method to add casbin permission enforcer. @@ -164,20 +151,16 @@ def register_casbin_enforcer( app.app.config["CASBIN_MODEL"] = casbin_model logger.info("Setting headers for owner operations.") - app.app.config["CASBIN_OWNER_HEADERS"] = ( - access_control_config.owner_headers - ) + app.app.config["CASBIN_OWNER_HEADERS"] = access_control_config.owner_headers logger.info("Setting headers for user operations.") - app.app.config["CASBIN_USER_NAME_HEADERS"] = ( - access_control_config.user_headers - ) + app.app.config["CASBIN_USER_NAME_HEADERS"] = access_control_config.user_headers logger.info("Setting up casbin enforcer.") adapter = Adapter( uri=f"mongodb://{mongo_config.host}:{mongo_config.port}/", dbname=str(access_control_config.db_name), - collection=access_control_config.collection_name + collection=access_control_config.collection_name, ) app.app.config["casbin_adapter"] = adapter @@ -204,6 +187,7 @@ def _decorator_check_permissions(fn): Returns: The response returned from the input function. """ + @wraps(fn) def _wrapper(*args, **kwargs) -> Tuple[Response, int]: """Wrapper for permissions decorator. @@ -219,11 +203,12 @@ def _wrapper(*args, **kwargs) -> Tuple[Response, int]: """ adapter = current_app.config["casbin_adapter"] casbin_enforcer = CasbinEnforcer(current_app, adapter) - response: Tuple[Response, int] = casbin_enforcer.enforcer( - func=fn - )(*args, **kwargs) + response: Tuple[Response, int] = casbin_enforcer.enforcer(func=fn)( + *args, **kwargs + ) if ( - len(response) == 2 and response[0].status_code == 200 + len(response) == 2 + and response[0].status_code == 200 and response[1] == 401 ): raise Forbidden diff --git a/foca/security/auth.py b/foca/security/auth.py index d0f3468c..6a681bf1 100644 --- a/foca/security/auth.py +++ b/foca/security/auth.py @@ -2,7 +2,7 @@ from connexion.exceptions import Unauthorized import logging -from typing import (Dict, Iterable, List, Optional) +from typing import Dict, Iterable, List, Optional from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey @@ -32,8 +32,8 @@ def validate_token(token: str) -> Dict: # Set parameters defined by OpenID Connect specification # Cf. https://openid.net/specs/openid-connect-discovery-1_0.html oidc_suffix_config: str = ".well-known/openid-configuration" - oidc_config_claim_userinfo: str = 'userinfo_endpoint' - oidc_config_claim_public_keys: str = 'jwks_uri' + oidc_config_claim_userinfo: str = "userinfo_endpoint" + oidc_config_claim_public_keys: str = "jwks_uri" # Fetch security parameters conf = current_app.config.foca.security.auth # type: ignore[attr-defined] @@ -49,16 +49,13 @@ def validate_token(token: str) -> Dict: # Ensure that validation methods are configured if not len(validation_methods): raise Unauthorized( - "Authentication is enabled, but no JWT validation methods " - "configured" + "Authentication is enabled, but no JWT validation methods configured" ) # Decode JWT try: claims = jwt.decode( - jwt=token, - algorithms=algorithms, - options={"verify_signature": False} + jwt=token, algorithms=algorithms, options={"verify_signature": False} ) except Exception as e: raise Unauthorized("JWT could not be decoded") from e @@ -66,9 +63,7 @@ def validate_token(token: str) -> Dict: # Verify existence of issuer claim if claim_issuer not in claims: - raise Unauthorized( - f"Required identity claim not available: {claim_identity}" - ) + raise Unauthorized(f"Required identity claim not available: {claim_identity}") # Get OIDC configuration url = f"{claims[claim_issuer].rstrip('/')}/{oidc_suffix_config}" @@ -77,21 +72,19 @@ def validate_token(token: str) -> Dict: oidc_config = requests.get(url) oidc_config.raise_for_status() except Exception as e: - raise Unauthorized( - f"Could not fetch issuer's configuration from: {url}" - ) from e + raise Unauthorized(f"Could not fetch issuer's configuration from: {url}") from e # Validate token passed_any = False for method in validation_methods: logger.debug(f"Validating JWT via method: {method}") try: - if method == 'userinfo': + if method == "userinfo": _validate_jwt_userinfo( url=oidc_config.json()[oidc_config_claim_userinfo], token=token, ) - if method == 'public_key': + if method == "public_key": _validate_jwt_public_key( url=oidc_config.json()[oidc_config_claim_public_keys], token=token, @@ -101,22 +94,20 @@ def validate_token(token: str) -> Dict: allow_expired=allow_expired, ) except Exception as e: - if validation_checks == 'all': + if validation_checks == "all": raise Unauthorized( "Insufficient number of JWT validation checks passed" ) from e continue passed_any = True - if validation_checks == 'any': + if validation_checks == "any": break if not passed_any: raise Unauthorized("No JWT validation checks passed") # Verify existence of specified identity claim if claim_identity not in claims: - raise Unauthorized( - f"Required identity claim '{claim_identity} not available" - ) + raise Unauthorized(f"Required identity claim '{claim_identity} not available") # Log result logger.debug(f"Access granted to user: {claims[claim_identity]}") @@ -124,24 +115,23 @@ def validate_token(token: str) -> Dict: req_headers = request.headers.__dict__ for key, val in claims.items(): req_headers[key] = val - req_headers['user_id'] = claims[claim_identity] - request.headers = \ - ImmutableMultiDict(req_headers) # type: ignore[assignment] + req_headers["user_id"] = claims[claim_identity] + request.headers = ImmutableMultiDict(req_headers) # type: ignore[assignment] # Return token info return { - 'jwt': token, - 'claims': claims, - 'user_id': claims[claim_identity], - 'scope': claims.get('scope', ""), + "jwt": token, + "claims": claims, + "user_id": claims[claim_identity], + "scope": claims.get("scope", ""), } def _validate_jwt_userinfo( token: str, url: str, - header_name: str = 'Authorization', - prefix: str = 'Bearer', + header_name: str = "Authorization", + prefix: str = "Bearer", ) -> None: """Validate JSON Web Token (JWT) via an OpenID Connect-compliant identity provider's user info endpoint. @@ -173,11 +163,11 @@ def _validate_jwt_userinfo( def _validate_jwt_public_key( token: str, url: str, - algorithms: List[str] = ['RS256'], + algorithms: List[str] = ["RS256"], add_key_to_claims: bool = True, audience: Optional[Iterable[str]] = None, allow_expired: bool = False, - claim_key_id: str = 'kid', + claim_key_id: str = "kid", ) -> None: """Validate JSON Web Token (JWT) via an OpenID Connect-compliant identity provider's public key. @@ -240,9 +230,9 @@ def _validate_jwt_public_key( # Set validations validation_options = {} if audience is None: - validation_options['verify_aud'] = False + validation_options["verify_aud"] = False if allow_expired: - validation_options['verify_exp'] = False + validation_options["verify_exp"] = False # Try public keys one after the other used_key: Dict = {} @@ -280,7 +270,7 @@ def _validate_jwt_public_key( # Add public key to claims if add_key_to_claims: - claims['public_key'] = used_key + claims["public_key"] = used_key # Log success and return claims logger.debug("Validation via issuer's public keys succeeded") @@ -289,8 +279,8 @@ def _validate_jwt_public_key( def _get_public_keys( url: str, pem: bool = False, - claim_key_id: str = 'kid', - claim_keys: str = 'keys', + claim_key_id: str = "kid", + claim_keys: str = "keys", ) -> Dict[str, RSAPublicKey]: """Obtain the identity provider's public JSON Web Key (JWK) set. @@ -330,16 +320,20 @@ def _get_public_keys( # Convert to PEM if requested if pem: - key = key.public_bytes( # type: ignore - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ).decode('utf-8').encode('unicode_escape').decode('utf-8') + key = ( + key.public_bytes( # type: ignore + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("utf-8") + .encode("unicode_escape") + .decode("utf-8") + ) public_keys[jwk[claim_key_id]] = key except Exception as e: logger.warning( - f"JSON Web Key '{jwk}' could not be processed: " - f"{type(e).__name__}: {e}" + f"JSON Web Key '{jwk}' could not be processed: {type(e).__name__}: {e}" ) # Return dictionary of public keys diff --git a/foca/security/cors.py b/foca/security/cors.py index bf110178..241840a5 100644 --- a/foca/security/cors.py +++ b/foca/security/cors.py @@ -17,4 +17,4 @@ def enable_cors(app: Flask) -> None: app: Flask application instance. """ CORS(app) - logger.debug('Enabled CORS for Flask app.') + logger.debug("Enabled CORS for Flask app.") diff --git a/foca/utils/db.py b/foca/utils/db.py index e44604c5..5d6cba4a 100644 --- a/foca/utils/db.py +++ b/foca/utils/db.py @@ -1,6 +1,6 @@ """Utility functions for interacting with a MongoDB database collection.""" -from typing import (Any, Mapping, Optional) +from typing import Any, Mapping, Optional from bson.objectid import ObjectId from pymongo.collection import Collection @@ -17,10 +17,7 @@ def find_one_latest(collection: Collection) -> Optional[Mapping[Any, Any]]: Newest document or ``None``, if no document exists. """ try: - return collection.find( - {}, - {'_id': False} - ).sort([('_id', -1)]).limit(1).next() + return collection.find({}, {"_id": False}).sort([("_id", -1)]).limit(1).next() except StopIteration: return None @@ -36,6 +33,6 @@ def find_id_latest(collection: Collection) -> Optional[ObjectId]: `ObjectId` of newest document or ``None``, if no document exists. """ try: - return collection.find().sort([('_id', -1)]).limit(1).next()['_id'] + return collection.find().sort([("_id", -1)]).limit(1).next()["_id"] except StopIteration: return None diff --git a/foca/utils/logging.py b/foca/utils/logging.py index 2a47aaa9..0e4687e0 100644 --- a/foca/utils/logging.py +++ b/foca/utils/logging.py @@ -3,7 +3,7 @@ import logging from connexion import request from functools import wraps -from typing import (Callable, Optional) +from typing import Callable, Optional logger = logging.getLogger(__name__) @@ -35,6 +35,7 @@ def _decorator_log_traffic(fn): Returns: The response returned from the input function. """ + @wraps(fn) def _wrapper(*args, **kwargs): """Wrapper for logging decorator. @@ -47,9 +48,9 @@ def _wrapper(*args, **kwargs): Wrapper function. """ req = ( - f"\"{request.environ['REQUEST_METHOD']} " + f'"{request.environ["REQUEST_METHOD"]} ' f"{request.environ['PATH_INFO']} " - f"{request.environ['SERVER_PROTOCOL']}\" from " + f'{request.environ["SERVER_PROTOCOL"]}" from ' f"{request.environ['REMOTE_ADDR']}" ) if log_request: diff --git a/foca/utils/misc.py b/foca/utils/misc.py index 599358f4..31609ae3 100644 --- a/foca/utils/misc.py +++ b/foca/utils/misc.py @@ -5,8 +5,7 @@ def generate_id( - charset: str = ''.join([string.ascii_letters, string.digits]), - length: int = 6 + charset: str = "".join([string.ascii_letters, string.digits]), length: int = 6 ) -> str: """Generate random string composed of specified character set. @@ -30,12 +29,8 @@ def generate_id( except Exception as e: raise TypeError(f"Could not evaluate 'charset': {charset}") from e if not isinstance(charset, str) or charset == "": - raise TypeError( - f"Could not evaluate 'charset' to non-empty string: {charset}" - ) + raise TypeError(f"Could not evaluate 'charset' to non-empty string: {charset}") if not isinstance(length, int) or not length > 0: - raise TypeError( - f"Argument to 'length' is not a positive integer: {length}" - ) - charset = ''.join(sorted(set(charset))) - return ''.join(choice(charset) for __ in range(length)) + raise TypeError(f"Argument to 'length' is not a positive integer: {length}") + charset = "".join(sorted(set(charset))) + return "".join(choice(charset) for __ in range(length)) diff --git a/foca/version.py b/foca/version.py index 4823bfd4..10652502 100644 --- a/foca/version.py +++ b/foca/version.py @@ -1,3 +1,3 @@ """Single source of truth for package version.""" -__version__ = '0.13.0' +__version__ = "0.13.0" diff --git a/poetry.lock b/poetry.lock index dbfda3a7..817c69ee 100644 --- a/poetry.lock +++ b/poetry.lock @@ -19,25 +19,11 @@ description = "A light, configurable Sphinx theme" optional = false python-versions = ">=3.9" groups = ["docs"] -markers = "python_version < \"3.11\"" files = [ {file = "alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92"}, {file = "alabaster-0.7.16.tar.gz", hash = "sha256:75a8b99c28a5dad50dd7f8ccdd447a121ddb3892da9e53d1ca5cca3106d58d65"}, ] -[[package]] -name = "alabaster" -version = "1.0.0" -description = "A light, configurable Sphinx theme" -optional = false -python-versions = ">=3.10" -groups = ["docs"] -markers = "python_version >= \"3.11\"" -files = [ - {file = "alabaster-1.0.0-py3-none-any.whl", hash = "sha256:fc6786402dc3fcb2de3cabd5fe455a2db534b371124f1f21de8731783dec828b"}, - {file = "alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e"}, -] - [[package]] name = "amqp" version = "5.3.1" @@ -65,21 +51,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[[package]] -name = "astroid" -version = "3.3.8" -description = "An abstract syntax tree for Python with inference support." -optional = false -python-versions = ">=3.9.0" -groups = ["code_quality"] -files = [ - {file = "astroid-3.3.8-py3-none-any.whl", hash = "sha256:187ccc0c248bfbba564826c26f070494f7bc964fd286b6d9fff4420e55de828c"}, - {file = "astroid-3.3.8.tar.gz", hash = "sha256:a88c7994f914a4ea8572fac479459f4955eeccc877be3f2d959a33273b0cf40b"}, -] - -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} - [[package]] name = "attrs" version = "25.1.0" @@ -483,12 +454,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main", "code_quality", "docs", "test"] +groups = ["main", "docs", "test"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "platform_system == \"Windows\"", code_quality = "sys_platform == \"win32\"", docs = "sys_platform == \"win32\"", test = "sys_platform == \"win32\""} +markers = {main = "platform_system == \"Windows\"", docs = "sys_platform == \"win32\"", test = "sys_platform == \"win32\""} [[package]] name = "connexion" @@ -651,22 +622,6 @@ ssh = ["bcrypt (>=3.1.5)"] test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] -[[package]] -name = "dill" -version = "0.3.9" -description = "serialize all of Python" -optional = false -python-versions = ">=3.8" -groups = ["code_quality"] -files = [ - {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, - {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] - [[package]] name = "dnspython" version = "2.7.0" @@ -716,41 +671,6 @@ files = [ [package.extras] test = ["pytest (>=6)"] -[[package]] -name = "flake8" -version = "7.0.0" -description = "the modular source code checker: pep8 pyflakes and co" -optional = false -python-versions = ">=3.8.1" -groups = ["code_quality"] -files = [ - {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, - {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, -] - -[package.dependencies] -mccabe = ">=0.7.0,<0.8.0" -pycodestyle = ">=2.11.0,<2.12.0" -pyflakes = ">=3.2.0,<3.3.0" - -[[package]] -name = "flake8-pyproject" -version = "1.2.3" -description = "Flake8 plug-in loading the configuration from pyproject.toml" -optional = false -python-versions = ">= 3.6" -groups = ["code_quality"] -files = [ - {file = "flake8_pyproject-1.2.3-py3-none-any.whl", hash = "sha256:6249fe53545205af5e76837644dc80b4c10037e73a0e5db87ff562d75fb5bd4a"}, -] - -[package.dependencies] -Flake8 = ">=5" -TOMLi = {version = "*", markers = "python_version < \"3.11\""} - -[package.extras] -dev = ["pyTest", "pyTest-cov"] - [[package]] name = "flask" version = "2.2.5" @@ -898,22 +818,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isort" -version = "6.0.0" -description = "A Python utility / library to sort Python imports." -optional = false -python-versions = ">=3.9.0" -groups = ["code_quality"] -files = [ - {file = "isort-6.0.0-py3-none-any.whl", hash = "sha256:567954102bb47bb12e0fae62606570faacddd441e45683968c8d1734fb1af892"}, - {file = "isort-6.0.0.tar.gz", hash = "sha256:75d9d8a1438a9432a7d7b54f2d3b45cad9a4a0fdba43617d9873379704a8bdf1"}, -] - -[package.extras] -colors = ["colorama"] -plugins = ["setuptools"] - [[package]] name = "itsdangerous" version = "2.2.0" @@ -1087,18 +991,6 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] -[[package]] -name = "mccabe" -version = "0.7.0" -description = "McCabe checker, plugin for flake8" -optional = false -python-versions = ">=3.6" -groups = ["code_quality"] -files = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, -] - [[package]] name = "mongomock" version = "4.3.0" @@ -1198,23 +1090,6 @@ files = [ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] -[[package]] -name = "platformdirs" -version = "4.3.6" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." -optional = false -python-versions = ">=3.8" -groups = ["code_quality"] -files = [ - {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, - {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, -] - -[package.extras] -docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] -type = ["mypy (>=1.11.2)"] - [[package]] name = "pluggy" version = "1.5.0" @@ -1246,18 +1121,6 @@ files = [ [package.dependencies] wcwidth = "*" -[[package]] -name = "pycodestyle" -version = "2.11.1" -description = "Python style guide checker" -optional = false -python-versions = ">=3.8" -groups = ["code_quality"] -files = [ - {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, - {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, -] - [[package]] name = "pycparser" version = "2.22" @@ -1405,18 +1268,6 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" -[[package]] -name = "pyflakes" -version = "3.2.0" -description = "passive checker of Python programs" -optional = false -python-versions = ">=3.8" -groups = ["code_quality"] -files = [ - {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, - {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, -] - [[package]] name = "pygments" version = "2.19.1" @@ -1450,37 +1301,6 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] -[[package]] -name = "pylint" -version = "3.3.4" -description = "python code static checker" -optional = false -python-versions = ">=3.9.0" -groups = ["code_quality"] -files = [ - {file = "pylint-3.3.4-py3-none-any.whl", hash = "sha256:289e6a1eb27b453b08436478391a48cd53bb0efb824873f949e709350f3de018"}, - {file = "pylint-3.3.4.tar.gz", hash = "sha256:74ae7a38b177e69a9b525d0794bd8183820bfa7eb68cc1bee6e8ed22a42be4ce"}, -] - -[package.dependencies] -astroid = ">=3.3.8,<=3.4.0-dev0" -colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version == \"3.11\""}, -] -isort = ">=4.2.5,<5.13.0 || >5.13.0,<7" -mccabe = ">=0.6,<0.8" -platformdirs = ">=2.2.0" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -tomlkit = ">=0.10.1" -typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} - -[package.extras] -spelling = ["pyenchant (>=3.2,<4.0)"] -testutils = ["gitpython (>3)"] - [[package]] name = "pymongo" version = "4.11.1" @@ -1730,23 +1550,6 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] -[[package]] -name = "roman-numerals-py" -version = "3.0.0" -description = "Manipulate well-formed Roman numerals" -optional = false -python-versions = ">=3.9" -groups = ["docs"] -markers = "python_version >= \"3.11\"" -files = [ - {file = "roman_numerals_py-3.0.0-py3-none-any.whl", hash = "sha256:a1421ce66b3eab7e8735065458de3fa5c4a46263d50f9f4ac8f0e5e7701dd125"}, - {file = "roman_numerals_py-3.0.0.tar.gz", hash = "sha256:91199c4373658c03d87d9fe004f4a5120a20f6cb192be745c2377cce274ef41c"}, -] - -[package.extras] -lint = ["mypy (==1.15.0)", "pyright (==1.1.394)", "ruff (==0.9.6)"] -test = ["pytest (>=8)"] - [[package]] name = "rpds-py" version = "0.22.3" @@ -1860,6 +1663,34 @@ files = [ {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"}, ] +[[package]] +name = "ruff" +version = "0.9.7" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +groups = ["code_quality"] +files = [ + {file = "ruff-0.9.7-py3-none-linux_armv6l.whl", hash = "sha256:99d50def47305fe6f233eb8dabfd60047578ca87c9dcb235c9723ab1175180f4"}, + {file = "ruff-0.9.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d59105ae9c44152c3d40a9c40d6331a7acd1cdf5ef404fbe31178a77b174ea66"}, + {file = "ruff-0.9.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f313b5800483770bd540cddac7c90fc46f895f427b7820f18fe1822697f1fec9"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042ae32b41343888f59c0a4148f103208bf6b21c90118d51dc93a68366f4e903"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87862589373b33cc484b10831004e5e5ec47dc10d2b41ba770e837d4f429d721"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a17e1e01bee0926d351a1ee9bc15c445beae888f90069a6192a07a84af544b6b"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c1f880ac5b2cbebd58b8ebde57069a374865c73f3bf41f05fe7a179c1c8ef22"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e63fc20143c291cab2841dbb8260e96bafbe1ba13fd3d60d28be2c71e312da49"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91ff963baed3e9a6a4eba2a02f4ca8eaa6eba1cc0521aec0987da8d62f53cbef"}, + {file = "ruff-0.9.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88362e3227c82f63eaebf0b2eff5b88990280fb1ecf7105523883ba8c3aaf6fb"}, + {file = "ruff-0.9.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0372c5a90349f00212270421fe91874b866fd3626eb3b397ede06cd385f6f7e0"}, + {file = "ruff-0.9.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d76b8ab60e99e6424cd9d3d923274a1324aefce04f8ea537136b8398bbae0a62"}, + {file = "ruff-0.9.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0c439bdfc8983e1336577f00e09a4e7a78944fe01e4ea7fe616d00c3ec69a3d0"}, + {file = "ruff-0.9.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:115d1f15e8fdd445a7b4dc9a30abae22de3f6bcabeb503964904471691ef7606"}, + {file = "ruff-0.9.7-py3-none-win32.whl", hash = "sha256:e9ece95b7de5923cbf38893f066ed2872be2f2f477ba94f826c8defdd6ec6b7d"}, + {file = "ruff-0.9.7-py3-none-win_amd64.whl", hash = "sha256:3770fe52b9d691a15f0b87ada29c45324b2ace8f01200fb0c14845e499eb0c2c"}, + {file = "ruff-0.9.7-py3-none-win_arm64.whl", hash = "sha256:b075a700b2533feb7a01130ff656a4ec0d5f340bb540ad98759b8401c32c2037"}, + {file = "ruff-0.9.7.tar.gz", hash = "sha256:643757633417907510157b206e490c3aa11cab0c087c912f60e07fbafa87a4c6"}, +] + [[package]] name = "sentinels" version = "1.0.0" @@ -1914,7 +1745,6 @@ description = "Python documentation generator" optional = false python-versions = ">=3.9" groups = ["docs"] -markers = "python_version < \"3.11\"" files = [ {file = "sphinx-7.4.7-py3-none-any.whl", hash = "sha256:c2419e2135d11f1951cd994d6eb18a1835bd8fdd8429f9ca375dc1f3281bd239"}, {file = "sphinx-7.4.7.tar.gz", hash = "sha256:242f92a7ea7e6c5b406fdc2615413890ba9f699114a9c09192d7dfead2ee9cfe"}, @@ -1945,43 +1775,6 @@ docs = ["sphinxcontrib-websupport"] lint = ["flake8 (>=6.0)", "importlib-metadata (>=6.0)", "mypy (==1.10.1)", "pytest (>=6.0)", "ruff (==0.5.2)", "sphinx-lint (>=0.9)", "tomli (>=2)", "types-docutils (==0.21.0.20240711)", "types-requests (>=2.30.0)"] test = ["cython (>=3.0)", "defusedxml (>=0.7.1)", "pytest (>=8.0)", "setuptools (>=70.0)", "typing_extensions (>=4.9)"] -[[package]] -name = "sphinx" -version = "8.2.0" -description = "Python documentation generator" -optional = false -python-versions = ">=3.11" -groups = ["docs"] -markers = "python_version >= \"3.11\"" -files = [ - {file = "sphinx-8.2.0-py3-none-any.whl", hash = "sha256:3c0a40ff71ace28b316bde7387d93b9249a3688c202181519689b66d5d0aed53"}, - {file = "sphinx-8.2.0.tar.gz", hash = "sha256:5b0067853d6e97f3fa87563e3404ebd008fce03525b55b25da90706764da6215"}, -] - -[package.dependencies] -alabaster = ">=0.7.14" -babel = ">=2.13" -colorama = {version = ">=0.4.6", markers = "sys_platform == \"win32\""} -docutils = ">=0.20,<0.22" -imagesize = ">=1.3" -Jinja2 = ">=3.1" -packaging = ">=23.0" -Pygments = ">=2.17" -requests = ">=2.30.0" -roman-numerals-py = ">=1.0.0" -snowballstemmer = ">=2.2" -sphinxcontrib-applehelp = ">=1.0.7" -sphinxcontrib-devhelp = ">=1.0.6" -sphinxcontrib-htmlhelp = ">=2.0.6" -sphinxcontrib-jsmath = ">=1.0.1" -sphinxcontrib-qthelp = ">=1.0.6" -sphinxcontrib-serializinghtml = ">=1.1.9" - -[package.extras] -docs = ["sphinxcontrib-websupport"] -lint = ["betterproto (==2.0.0b6)", "mypy (==1.15.0)", "pypi-attestations (==0.0.21)", "pyright (==1.1.394)", "pytest (>=8.0)", "ruff (==0.9.6)", "sphinx-lint (>=0.9)", "types-Pillow (==10.2.0.20240822)", "types-Pygments (==2.19.0.20250107)", "types-colorama (==0.4.15.20240311)", "types-defusedxml (==0.7.0.20240218)", "types-docutils (==0.21.0.20241128)", "types-requests (==2.32.0.20241016)", "types-urllib3 (==1.26.25.14)"] -test = ["cython (>=3.0)", "defusedxml (>=0.7.1)", "pytest (>=8.0)", "pytest-xdist[psutil] (>=3.4)", "setuptools (>=70.0)", "typing_extensions (>=4.9)"] - [[package]] name = "sphinx-rtd-theme" version = "3.0.2" @@ -2187,18 +1980,6 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] -[[package]] -name = "tomlkit" -version = "0.13.2" -description = "Style preserving TOML library" -optional = false -python-versions = ">=3.8" -groups = ["code_quality"] -files = [ - {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, - {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, -] - [[package]] name = "types-pyyaml" version = "6.0.12.20241230" @@ -2370,4 +2151,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.9.0,<4.0.0" -content-hash = "1f1d49c95035b999762397334646627f3eb29101fa9b53ed14e39dfdfd6d1cdb" +content-hash = "7a9c8fcd0984ae40a70abe54858af2b7c9967d743bdd007d02cc149c665c6cbf" diff --git a/pyproject.toml b/pyproject.toml index 94429baa..4d67a6ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,10 +5,6 @@ requires = ["poetry-core"] [tool.mypy] ignore_missing_imports = true -[tool.flake8] -exclude = [".git", ".eggs", "build", ".venv", "venv", "env"] -max-line-length = 79 - [tool.poetry] authors = ["ELIXIR Cloud & AAI"] classifiers = [ @@ -64,11 +60,9 @@ Werkzeug = "~=2.2" optional = true [tool.poetry.group.code_quality.dependencies] -flake8 = ">=6.1" -Flake8-pyproject = ">=1.2.3" mypy = ">=0.991" mypy-extensions = ">=0.4.4" -pylint = ">=3.2" +ruff = ">=0.4.10" [tool.poetry.group.docs] optional = true @@ -95,3 +89,28 @@ types-requests = "*" types-setuptools = "*" types-urllib3 = "*" typing_extensions = ">=4.11" + +[tool.ruff] +indent-width = 4 + +[tool.ruff.format] +docstring-code-format = true +indent-style = "space" +line-ending = "lf" +quote-style = "double" + +# [tool.ruff.lint] +# select = [ +# "B", # flake8-bugbear +# "C90", # mccabe +# "D", # pydocstyle +# "E", # pycodestyle +# "F", # Pyflakes +# "I", # isort +# "PL", # pylint +# "SIM", # flake8-simplify +# "UP", # pyupgrade +# ] + +[tool.ruff.lint.pydocstyle] +convention = "google" \ No newline at end of file diff --git a/tests/api/test_register_openapi.py b/tests/api/test_register_openapi.py index a26de266..1b4ab3bb 100644 --- a/tests/api/test_register_openapi.py +++ b/tests/api/test_register_openapi.py @@ -21,8 +21,12 @@ PATH_SPECS_2_YAML_ADDITION = DIR / "openapi_2_petstore.addition.yaml" PATH_SPECS_3_YAML_ORIGINAL = DIR / "openapi_3_petstore.original.yaml" PATH_SPECS_3_YAML_MODIFIED = DIR / "openapi_3_petstore.modified.yaml" -PATH_SPECS_3_PATHITEMPARAM_YAML_ORIGINAL = DIR / "openapi_3_petstore_pathitemparam.original.yaml" # noqa: E501 -PATH_SPECS_3_PATHITEMPARAM_YAML_MODIFIED = DIR / "openapi_3_petstore_pathitemparam.modified.yaml" # noqa: E501 +PATH_SPECS_3_PATHITEMPARAM_YAML_ORIGINAL = ( + DIR / "openapi_3_petstore_pathitemparam.original.yaml" +) # noqa: E501 +PATH_SPECS_3_PATHITEMPARAM_YAML_MODIFIED = ( + DIR / "openapi_3_petstore_pathitemparam.modified.yaml" +) # noqa: E501 PATH_SPECS_INVALID_JSON = DIR / "invalid.json" PATH_SPECS_INVALID_YAML = DIR / "invalid.openapi.yaml" PATH_NOT_FOUND = DIR / "does/not/exist.yaml" @@ -35,9 +39,7 @@ "info": { "version": "1.0.0", "title": "Swagger Petstore", - "license": { - "name": "MIT" - } + "license": {"name": "MIT"}, } } CONNEXION_CONFIG = { @@ -46,7 +48,7 @@ "options": { "swagger_ui": True, "serve_spec": True, - } + }, } SPEC_CONFIG_2 = { "path": PATH_SPECS_2_YAML_ORIGINAL, @@ -76,19 +78,18 @@ "connexion": CONNEXION_CONFIG, } SPEC_CONFIG_2_JSON = deepcopy(SPEC_CONFIG_2) -SPEC_CONFIG_2_JSON['path'] = PATH_SPECS_2_JSON_ORIGINAL +SPEC_CONFIG_2_JSON["path"] = PATH_SPECS_2_JSON_ORIGINAL SPEC_CONFIG_2_LIST = deepcopy(SPEC_CONFIG_2) -SPEC_CONFIG_2_LIST['path'] = [PATH_SPECS_2_YAML_ORIGINAL] +SPEC_CONFIG_2_LIST["path"] = [PATH_SPECS_2_YAML_ORIGINAL] SPEC_CONFIG_2_MULTI = deepcopy(SPEC_CONFIG_2_LIST) -SPEC_CONFIG_2_MULTI['path'].append(PATH_SPECS_2_YAML_ADDITION) +SPEC_CONFIG_2_MULTI["path"].append(PATH_SPECS_2_YAML_ADDITION) SPEC_CONFIG_2_DISABLE_AUTH = deepcopy(SPEC_CONFIG_2) -SPEC_CONFIG_2_DISABLE_AUTH['disable_auth'] = True +SPEC_CONFIG_2_DISABLE_AUTH["disable_auth"] = True SPEC_CONFIG_3_DISABLE_AUTH = deepcopy(SPEC_CONFIG_3) -SPEC_CONFIG_3_DISABLE_AUTH['disable_auth'] = True +SPEC_CONFIG_3_DISABLE_AUTH["disable_auth"] = True class TestRegisterOpenAPI: - def test_openapi_2_yaml(self): """Successfully register OpenAPI 2 YAML specs with Connexion app.""" app = App(__name__) diff --git a/tests/config/test_config_parser.py b/tests/config/test_config_parser.py index 847b53f6..5064bad1 100644 --- a/tests/config/test_config_parser.py +++ b/tests/config/test_config_parser.py @@ -18,9 +18,9 @@ PATH = str(DIR / "openapi_2_petstore.original.yaml") PATH_ADDITION = str(DIR / "openapi_2_petstore.addition.yaml") TEST_CONFIG_INSTANCE = Config() -TEST_CONFIG_MODEL = 'tests.test_files.model_valid.CustomConfig' -TEST_CONFIG_MODEL_NOT_EXISTS = 'tests.test_files.model_valid.NotExists' -TEST_CONFIG_MODEL_MODULE_NOT_EXISTS = 'tests.test_files.not_a_module.NotExists' +TEST_CONFIG_MODEL = "tests.test_files.model_valid.CustomConfig" +TEST_CONFIG_MODEL_NOT_EXISTS = "tests.test_files.model_valid.NotExists" +TEST_CONFIG_MODEL_MODULE_NOT_EXISTS = "tests.test_files.not_a_module.NotExists" TEST_DICT = {} TEST_FILE = "tests/test_files/conf_valid.yaml" TEST_FILE_CUSTOM_INVALID = "tests/test_files/conf_valid_custom_invalid.yaml" @@ -32,7 +32,7 @@ def test_config_parser_valid_config_file(): """Test valid YAML parsing.""" conf = ConfigParser(Path(TEST_FILE)) - assert type(conf.config.model_dump()) == type(TEST_DICT) + assert type(conf.config.model_dump()) is type(TEST_DICT) assert isinstance(conf.config, type(TEST_CONFIG_INSTANCE)) @@ -52,7 +52,7 @@ def test_config_parser_invalid_file_path(): def test_config_parser_invalid_log_config(): """Test invalid log config YAML.""" conf = ConfigParser(Path(TEST_FILE_INVALID_LOG)) - assert type(conf.config.model_dump()) == type(TEST_DICT) + assert type(conf.config.model_dump()) is type(TEST_DICT) assert isinstance(conf.config, type(TEST_CONFIG_INSTANCE)) @@ -97,7 +97,7 @@ def test_merge_yaml_with_two_args(): """Test merge_yaml with no arguments.""" yaml_list = [Path(PATH), Path(PATH_ADDITION)] res = ConfigParser.merge_yaml(*yaml_list) - assert 'put' in res['paths']['/pets/{petId}'] + assert "put" in res["paths"]["/pets/{petId}"] def test_parse_custom_config_valid_model(): diff --git a/tests/database/test_register_mongodb.py b/tests/database/test_register_mongodb.py index e4f340ef..2b658653 100644 --- a/tests/database/test_register_mongodb.py +++ b/tests/database/test_register_mongodb.py @@ -10,31 +10,26 @@ from foca.models.config import MongoConfig MONGO_DICT_MIN = { - 'host': 'mongodb', - 'port': 27017, -} -DB_DICT_NO_COLL = { - 'my_db': { - 'collections': None - } + "host": "mongodb", + "port": 27017, } +DB_DICT_NO_COLL = {"my_db": {"collections": None}} DB_DICT_DEF_COLL = { - 'my_db': { - 'collections': { - 'my_collection': { - 'indexes': None, + "my_db": { + "collections": { + "my_collection": { + "indexes": None, } } } } DB_DICT_CUST_COLL = { - 'my_db': { - 'collections': { - 'my_collection': { - 'indexes': [{ - 'keys': {'indexed_field': 1}, - 'options': {'sparse': False} - }] + "my_db": { + "collections": { + "my_collection": { + "indexes": [ + {"keys": {"indexed_field": 1}, "options": {"sparse": False}} + ] } } } @@ -47,7 +42,7 @@ def test__create_mongo_client(monkeypatch): """When MONGO_USERNAME environement variable is NOT defined""" - monkeypatch.setenv("MONGO_USERNAME", 'None') + monkeypatch.setenv("MONGO_USERNAME", "None") app = Flask(__name__) res = _create_mongo_client( app=app, @@ -65,7 +60,7 @@ def test__create_mongo_client_auth(monkeypatch): def test__create_mongo_client_auth_empty(monkeypatch): """When MONGO_USERNAME environment variable IS defined but empty""" - monkeypatch.setenv("MONGO_USERNAME", '') + monkeypatch.setenv("MONGO_USERNAME", "") app = Flask(__name__) res = _create_mongo_client(app) assert isinstance(res, PyMongo) @@ -104,11 +99,11 @@ def test_register_mongodb_def_collections(): def test_register_mongodb_cust_collections(monkeypatch): """Register MongoDB with collections and custom indexes""" monkeypatch.setattr( - 'pymongo.collection.Collection.create_index', + "pymongo.collection.Collection.create_index", lambda *args, **kwargs: None, ) monkeypatch.setattr( - 'pymongo.collection.Collection.drop_indexes', + "pymongo.collection.Collection.drop_indexes", lambda *args, **kwargs: None, ) app = Flask(__name__) diff --git a/tests/errors/test_errors.py b/tests/errors/test_errors.py index 93eb91dd..2e6f76f5 100644 --- a/tests/errors/test_errors.py +++ b/tests/errors/test_errors.py @@ -5,7 +5,7 @@ from copy import deepcopy import json -from flask import (Flask, Response) +from flask import Flask, Response from connexion import App import pytest @@ -20,7 +20,7 @@ from foca.models.config import Config EXCEPTION_INSTANCE = Exception() -INVALID_LOG_FORMAT = 'unknown_log_format' +INVALID_LOG_FORMAT = "unknown_log_format" TEST_DICT = { "title": "MyException", "details": { @@ -29,7 +29,7 @@ }, "status": 400, } -TEST_KEYS = ['details', 'code'] +TEST_KEYS = ["details", "code"] EXPECTED_SUBSET_RESULT = { "details": { "code": 400, @@ -42,8 +42,8 @@ }, "status": 400, } -PUBLIC_MEMBERS = [['title']] -PRIVATE_MEMBERS = [['status']] +PUBLIC_MEMBERS = [["title"]] +PRIVATE_MEMBERS = [["status"]] class UnknownException(Exception): @@ -63,7 +63,7 @@ def test__exc_to_str(): assert isinstance(res, str) -@pytest.mark.parametrize("format", ['oneline', 'minimal', 'regular']) +@pytest.mark.parametrize("format", ["oneline", "minimal", "regular"]) def test__log_exception(caplog, format): """Test exception reformatter function.""" _log_exception( @@ -84,45 +84,39 @@ def test__log_exception_invalid_format(caplog): def test__subset_nested_dict(): """Test nested dictionary subsetting function.""" - res = _subset_nested_dict( - obj=TEST_DICT, - key_sequence=deepcopy(TEST_KEYS) - ) + res = _subset_nested_dict(obj=TEST_DICT, key_sequence=deepcopy(TEST_KEYS)) assert res == EXPECTED_SUBSET_RESULT def test__exclude_key_nested_dict(): """Test function to exclude a key from a nested dictionary.""" - res = _exclude_key_nested_dict( - obj=TEST_DICT, - key_sequence=deepcopy(TEST_KEYS) - ) + res = _exclude_key_nested_dict(obj=TEST_DICT, key_sequence=deepcopy(TEST_KEYS)) assert res == EXPECTED_EXCLUDE_RESULT def test__problem_handler_json(): """Test problem handler with instance of custom, unlisted error.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) EXPECTED_RESPONSE = app.config.foca.exceptions.mapping[Exception] with app.app_context(): res = _problem_handler_json(UnknownException()) assert isinstance(res, Response) - assert res.status == '500 INTERNAL SERVER ERROR' + assert res.status == "500 INTERNAL SERVER ERROR" assert res.mimetype == "application/problem+json" - response = json.loads(res.data.decode('utf-8')) + response = json.loads(res.data.decode("utf-8")) assert response == EXPECTED_RESPONSE def test__problem_handler_json_no_fallback_exception(): """Test problem handler; unlisted error without fallback.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) del app.config.foca.exceptions.mapping[Exception] with app.app_context(): res = _problem_handler_json(UnknownException()) assert isinstance(res, Response) - assert res.status == '500 INTERNAL SERVER ERROR' + assert res.status == "500 INTERNAL SERVER ERROR" assert res.mimetype == "application/problem+json" response = res.data.decode("utf-8") assert response == "" @@ -131,22 +125,22 @@ def test__problem_handler_json_no_fallback_exception(): def test__problem_handler_json_with_public_members(): """Test problem handler with public members.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) app.config.foca.exceptions.public_members = PUBLIC_MEMBERS with app.app_context(): res = _problem_handler_json(UnknownException()) assert isinstance(res, Response) - assert res.status == '500 INTERNAL SERVER ERROR' + assert res.status == "500 INTERNAL SERVER ERROR" assert res.mimetype == "application/problem+json" def test__problem_handler_json_with_private_members(): """Test problem handler with private members.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) app.config.foca.exceptions.private_members = PRIVATE_MEMBERS with app.app_context(): res = _problem_handler_json(UnknownException()) assert isinstance(res, Response) - assert res.status == '500 INTERNAL SERVER ERROR' + assert res.status == "500 INTERNAL SERVER ERROR" assert res.mimetype == "application/problem+json" diff --git a/tests/factories/test_celery_app.py b/tests/factories/test_celery_app.py index e46838ef..c722460f 100644 --- a/tests/factories/test_celery_app.py +++ b/tests/factories/test_celery_app.py @@ -4,7 +4,7 @@ from foca.factories.celery_app import create_celery_app from foca.factories.connexion_app import create_connexion_app -from foca.models.config import (Config, JobsConfig) +from foca.models.config import Config, JobsConfig CONFIG = Config() CONFIG.jobs = JobsConfig() diff --git a/tests/factories/test_connexion_app.py b/tests/factories/test_connexion_app.py index 7da4b787..21c29c98 100644 --- a/tests/factories/test_connexion_app.py +++ b/tests/factories/test_connexion_app.py @@ -6,13 +6,13 @@ from foca.factories.connexion_app import ( __add_config_to_connexion_app, create_connexion_app, - ) +) CONFIG = Config() ERROR_CODE = 400 ERROR_ORIGINAL = { - 'title': 'BAD REQUEST', - 'status_code': str(ERROR_CODE), + "title": "BAD REQUEST", + "status_code": str(ERROR_CODE), } ERROR_REWRITTEN = { "msg": "The request is malformed.", diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 6f41ea44..2d1440c9 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -48,7 +48,7 @@ def test_add_pet_extra_parameter_200(): assert isinstance(response_data, Pet) assert response_data.name == NAME_PET assert response_data.tag == TAG_PET - assert getattr(response_data, 'extra_parameter', None) is None + assert getattr(response_data, "extra_parameter", None) is None def test_add_pet_required_arguments_missing_400(): diff --git a/tests/mock_data.py b/tests/mock_data.py index 4b830286..603a4eac 100644 --- a/tests/mock_data.py +++ b/tests/mock_data.py @@ -1,10 +1,9 @@ """Mock data for testing.""" + from pathlib import Path -INDEX_CONFIG = { - "keys": [("id", 1)] -} +INDEX_CONFIG = {"keys": [("id", 1)]} COLLECTION_CONFIG = { "indexes": [INDEX_CONFIG], } @@ -28,25 +27,14 @@ "collection_name": "policy_rules", "owner_headers": ["X-User", "X-Group"], "user_headers": ["X-User"], - "model": MODEL_CONF_FILE + "model": MODEL_CONF_FILE, } MOCK_ID = "mock_id" -MOCK_RULE = { - "ptype": "p1", - "v0": "alice", - "v1": "data1", - "v2": "POST", - "v3": "read" -} +MOCK_RULE = {"ptype": "p1", "v0": "alice", "v1": "data1", "v2": "POST", "v3": "read"} MOCK_RULE_USER_INPUT_OUTPUT = { "policy_type": "p1", - "rule": { - "v0": "alice", - "v1": "data1", - "v2": "POST", - "v3": "read" - } + "rule": {"v0": "alice", "v1": "data1", "v2": "POST", "v3": "read"}, } MOCK_RULE_INVALID = {"rule": []} MOCK_PERMISSION = ["alice", "/", "GET"] @@ -54,5 +42,5 @@ "REQUEST_METHOD": "GET", "PATH_INFO": "/", "SERVER_PROTOCOL": "HTTP/1.1", - "REMOTE_ADDR": "192.168.1.1" + "REMOTE_ADDR": "192.168.1.1", } diff --git a/tests/models/test_config.py b/tests/models/test_config.py index bcb760e7..933a5444 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -18,76 +18,71 @@ DIR = Path(__file__).parent / "test_files" EXCEPTIONS_NO_DICT = [] -EXCEPTIONS_NOT_NESTED = {Exception: 'b'} -EXCEPTIONS_NOT_EXC = {'a': {'status': 400, 'title': 'Bad Request'}} -REQUIRED_MEMBERS = [['title'], ['status']] -MEMBER_TITLE = ['title'] -MEMBER_STATUS = ['status'] -MEMBER_NA = ['some', 'field'] -MEMBERS_NA = [ - MEMBER_NA, - MEMBER_NA + ['more'] -] +EXCEPTIONS_NOT_NESTED = {Exception: "b"} +EXCEPTIONS_NOT_EXC = {"a": {"status": 400, "title": "Bad Request"}} +REQUIRED_MEMBERS = [["title"], ["status"]] +MEMBER_TITLE = ["title"] +MEMBER_STATUS = ["status"] +MEMBER_NA = ["some", "field"] +MEMBERS_NA = [MEMBER_NA, MEMBER_NA + ["more"]] MODULE_NA = "some.unavailable.module" MODULE_WITHOUT_EXCEPTIONS = "foca.foca.exceptions" -MODULE_PATCH_NO_DICT = 'some.path.EXCEPTIONS_NO_DICT' -MODULE_PATCH_NOT_NESTED = 'some.path.EXCEPTIONS_NOT_NESTED' -MODULE_PATCH_NOT_EXC = 'some.path.EXCEPTIONS_NOT_EXC' +MODULE_PATCH_NO_DICT = "some.path.EXCEPTIONS_NO_DICT" +MODULE_PATCH_NOT_NESTED = "some.path.EXCEPTIONS_NOT_NESTED" +MODULE_PATCH_NOT_EXC = "some.path.EXCEPTIONS_NOT_EXC" PATH = str(DIR / "openapi_2_petstore.yaml") PATH_MODIFIED = str(DIR / "openapi_2_petstore.modified.yaml") PATH_ADDITION = str(DIR / "openapi_2_petstore.addition.yaml") INDEX_CONFIG = { - 'keys': {'last_name': -1}, - 'options': { - 'name': 'indexLastName', - 'unique': True, - 'background': False, - 'sparse': False - } + "keys": {"last_name": -1}, + "options": { + "name": "indexLastName", + "unique": True, + "background": False, + "sparse": False, + }, } COLLECTION_CONFIG = { - 'indexes': [INDEX_CONFIG], + "indexes": [INDEX_CONFIG], } DB_CONFIG = { - 'collections': { - 'wes-col': COLLECTION_CONFIG, + "collections": { + "wes-col": COLLECTION_CONFIG, }, } MONGO_CONFIG = { - 'host': 'mongodb', - 'port': 27017, - 'dbs': { - 'wes': DB_CONFIG, + "host": "mongodb", + "port": 27017, + "dbs": { + "wes": DB_CONFIG, }, } SPEC_CONFIG = { - 'path': '/my/abs/path', - 'path_out': '/my/abs/out/path', + "path": "/my/abs/path", + "path_out": "/my/abs/out/path", } SPEC_CONFIG_REL_IN = { - 'path': 'path', - 'path_out': '/my/abs/out/path', + "path": "path", + "path_out": "/my/abs/out/path", } SPEC_CONFIG_REL_OUT = { - 'path': '/my/abs/path', - 'path_out': 'path', + "path": "/my/abs/path", + "path_out": "path", } SPEC_CONFIG_REL_IO = { - 'path': '/my/abs/path', - 'path_out': '/my/abs/out/path', + "path": "/my/abs/path", + "path_out": "/my/abs/out/path", } SPEC_CONFIG_NO_OUT = { - 'path': '/my/abs/path', + "path": "/my/abs/path", } SPEC_CONFIG_REL_IN_NO_OUT = { - 'path': 'path', + "path": "path", } SPEC_CONFIG_NO_IN = { - 'path_out': '/my/abs/out/path', -} -SPEC_CONFIG_LIST_NO_OUT = { - 'path': ['path1', 'path2'] + "path_out": "/my/abs/out/path", } +SPEC_CONFIG_LIST_NO_OUT = {"path": ["path1", "path2"]} def test_config_empty(): @@ -142,8 +137,7 @@ def test_exception_config_with_wrong_exceptions_type(monkeypatch): """Test creation of the ExceptionConfig model; exceptions object is not of dictionary type.""" monkeypatch.setattr( - 'importlib.import_module', - lambda *args, **kwargs: sys.modules[__name__] + "importlib.import_module", lambda *args, **kwargs: sys.modules[__name__] ) with pytest.raises(ValidationError): ExceptionConfig( @@ -157,7 +151,7 @@ def test_exception_config_with_optional_status_member(): with pytest.raises(ValidationError): ExceptionConfig( extension_members=True, - status_member=['sdf', 'asd'], + status_member=["sdf", "asd"], ) @@ -205,8 +199,7 @@ def test_exception_config_with_wrong_exception_values(monkeypatch): """Test creation of the ExceptionConfig model; exceptions object is not a dictionary of dictionaries.""" monkeypatch.setattr( - 'importlib.import_module', - lambda *args, **kwargs: sys.modules[__name__] + "importlib.import_module", lambda *args, **kwargs: sys.modules[__name__] ) with pytest.raises(ValidationError): ExceptionConfig( @@ -218,8 +211,7 @@ def test_exception_config_with_wrong_exception_keys(monkeypatch): """Test creation of the ExceptionConfig model; exceptions object is not a dictionary of exceptions.""" monkeypatch.setattr( - 'importlib.import_module', - lambda *args, **kwargs: sys.modules[__name__] + "importlib.import_module", lambda *args, **kwargs: sys.modules[__name__] ) with pytest.raises(ValidationError): ExceptionConfig( @@ -231,9 +223,7 @@ def test_exception_config_missing_required_members(): """Test creation of the ExceptionConfig model; required members are missing.""" with pytest.raises(ValidationError): - ExceptionConfig( - required_members=REQUIRED_MEMBERS + MEMBERS_NA - ) + ExceptionConfig(required_members=REQUIRED_MEMBERS + MEMBERS_NA) def test_exception_config_forbidden_extension_members(): @@ -366,7 +356,7 @@ def test_spec_config_list_no_out(): def test_SpecConfig_full(): """Test SpecConfig instantiation; full example""" res = SpecConfig(**SPEC_CONFIG) - assert str(res.path_out) == SPEC_CONFIG['path_out'] + assert str(res.path_out) == SPEC_CONFIG["path_out"] def test_SpecConfig_minimal(): diff --git a/tests/security/access_control/foca_casbin_adapter/test_adapter.py b/tests/security/access_control/foca_casbin_adapter/test_adapter.py index 4bb4827f..998e307f 100644 --- a/tests/security/access_control/foca_casbin_adapter/test_adapter.py +++ b/tests/security/access_control/foca_casbin_adapter/test_adapter.py @@ -17,7 +17,7 @@ ("p", "p", ["bob", "data2", "write"]), ("p", "p", ["data2_admin", "data2", "read"]), ("p", "p", ["data2_admin", "data2", "write"]), - ("g", "g", ["alice", "data2_admin"]) + ("g", "g", ["alice", "data2_admin"]), ] TEST_POLICIES_MODEL_ROLES_CONF = [ ("p", "p", ["alice", "data1", "write"]), @@ -25,7 +25,7 @@ ("p", "p", ["bob", "data2", "read"]), ("p", "p", ["data_group_admin", "data_group", "write"]), ("g", "g", ["alice", "data_group_admin"]), - ("g", "g2", ["data2", "data_group"]) + ("g", "g2", ["data2", "data_group"]), ] @@ -43,12 +43,7 @@ def setUp(self): def tearDown(self): self.clear_db() - def save_policies( - self, - adapter: Adapter, - model: Model, - policy: Tuple - ) -> None: + def save_policies(self, adapter: Adapter, model: Model, policy: Tuple) -> None: """Helper function for adding policy to a given model. Args: @@ -65,11 +60,7 @@ def save_policies( model.add_policy(*policy) adapter.save_policy(model) - def get_enforcer( - self, - conf_file: str, - policies: List - ) -> Enforcer: + def get_enforcer(self, conf_file: str, policies: List) -> Enforcer: """Helper function to register policy enforcer. Args: @@ -121,9 +112,7 @@ def test_add_policy(self): policy1 = adapter.add_policy( sec="p", ptype="p", rule=("alice", "data1", "write") ) - policy2 = adapter.add_policy( - sec="p", ptype="p", rule=("bob", "data2", "read") - ) + policy2 = adapter.add_policy(sec="p", ptype="p", rule=("bob", "data2", "read")) e.load_policy() assert policy1 is True @@ -156,8 +145,7 @@ def test_remove_policy_with_incomplete_rule(self): adapter = Adapter(f"mongodb://localhost:{self.db_port}", self.db_name) e = Enforcer(MODEL_ROLES_CONF_FILE, adapter) e = self.get_enforcer( - conf_file=MODEL_ROLES_CONF_FILE, - policies=TEST_POLICIES_MODEL_ROLES_CONF + conf_file=MODEL_ROLES_CONF_FILE, policies=TEST_POLICIES_MODEL_ROLES_CONF ) e.load_policy() @@ -185,16 +173,13 @@ def test_remove_policy_with_empty_rule(self): adapter = Adapter(f"mongodb://localhost:{self.db_port}", self.db_name) e = Enforcer(MODEL_ROLES_CONF_FILE, adapter) e = self.get_enforcer( - conf_file=MODEL_ROLES_CONF_FILE, - policies=TEST_POLICIES_MODEL_ROLES_CONF + conf_file=MODEL_ROLES_CONF_FILE, policies=TEST_POLICIES_MODEL_ROLES_CONF ) e.load_policy() assert e.enforce("alice", "data1", "write") is True - remove_policy = adapter.remove_policy( - sec="p", ptype=None, rule=() - ) + remove_policy = adapter.remove_policy(sec="p", ptype=None, rule=()) e.load_policy() assert remove_policy is False @@ -228,9 +213,7 @@ def test_remove_filtered_policy(self): assert e.enforce("alice", "data2", "read") is True assert e.enforce("alice", "data2", "write") is True - result = adapter.remove_filtered_policy( - "g", "g", 6, "alice", "data2_admin" - ) + result = adapter.remove_filtered_policy("g", "g", 6, "alice", "data2_admin") e.load_policy() assert result is False @@ -240,9 +223,7 @@ def test_remove_filtered_policy(self): e.load_policy() assert result is False - result = adapter.remove_filtered_policy( - "g", "g", 0, "alice", "data2_admin" - ) + result = adapter.remove_filtered_policy("g", "g", 0, "alice", "data2_admin") e.load_policy() assert result is True assert e.enforce("alice", "data1", "read") is True diff --git a/tests/security/access_control/foca_casbin_adapter/test_casbin_rule.py b/tests/security/access_control/foca_casbin_adapter/test_casbin_rule.py index f6c436bf..de606f6f 100644 --- a/tests/security/access_control/foca_casbin_adapter/test_casbin_rule.py +++ b/tests/security/access_control/foca_casbin_adapter/test_casbin_rule.py @@ -1,9 +1,6 @@ -"""Tests for initialising casbin rule object. -""" +"""Tests for initialising casbin rule object.""" -from foca.security.access_control.foca_casbin_adapter.casbin_rule import ( - CasbinRule -) +from foca.security.access_control.foca_casbin_adapter.casbin_rule import CasbinRule # Define data BASE_RULE_OBJECT = { @@ -13,7 +10,7 @@ "v2": "v2", "v3": "v3", "v4": "v4", - "v5": "v5" + "v5": "v5", } BASE_RULE_STR_REPRESENTATION = "ptype, v0, v1, v2, v3, v4, v5" BASE_RULE_REPRESENTATION = f'' diff --git a/tests/security/access_control/test_access_control_server.py b/tests/security/access_control/test_access_control_server.py index 4661952f..c6338caa 100644 --- a/tests/security/access_control/test_access_control_server.py +++ b/tests/security/access_control/test_access_control_server.py @@ -16,14 +16,15 @@ getPermission, getAllPermissions, postPermission, - putPermission + putPermission, ) from foca.security.access_control.foca_casbin_adapter.adapter import Adapter from foca.security.access_control.constants import ( - ACCESS_CONTROL_BASE_PATH, DEFAULT_MODEL_FILE + ACCESS_CONTROL_BASE_PATH, + DEFAULT_MODEL_FILE, ) from foca.errors.exceptions import BadRequest, InternalServerError, NotFound -from foca.models.config import (AccessControlConfig, Config, MongoConfig) +from foca.models.config import AccessControlConfig, Config, MongoConfig from tests.mock_data import ( ACCESS_CONTROL_CONFIG, @@ -31,7 +32,7 @@ MOCK_RULE, MOCK_RULE_USER_INPUT_OUTPUT, MOCK_RULE_INVALID, - MONGO_CONFIG + MONGO_CONFIG, ) @@ -76,10 +77,12 @@ def test_getPermission(self): app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) mock_resp["id"] = MOCK_ID - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) del mock_resp["_id"] data = deepcopy(MOCK_RULE_USER_INPUT_OUTPUT) @@ -100,10 +103,12 @@ def test_getPermission_NotFound(self): app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) mock_resp["id"] = MOCK_ID - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) del mock_resp["_id"] with app.app_context(): @@ -130,10 +135,12 @@ def test_deletePermission(self): app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) mock_resp["id"] = MOCK_ID - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) with app.app_context(): res = deletePermission.__wrapped__(id=MOCK_ID) @@ -148,10 +155,12 @@ def test_deletePermission_NotFound(self): ) app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) with app.app_context(): with pytest.raises(NotFound): @@ -180,14 +189,16 @@ def test_getAllPermissions(self): ) app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) - mock_resp['id'] = MOCK_ID - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + mock_resp["id"] = MOCK_ID + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) data = deepcopy(MOCK_RULE_USER_INPUT_OUTPUT) - data['id'] = MOCK_ID + data["id"] = MOCK_ID with app.app_context(): res = getAllPermissions.__wrapped__() assert res == [data] @@ -204,14 +215,16 @@ def test_getAllPermissions_filters(self): ) app.config.foca = base_config mock_resp = deepcopy(MOCK_RULE) - mock_resp['id'] = MOCK_ID - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client.insert_one(mock_resp) + mock_resp["id"] = MOCK_ID + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client.insert_one(mock_resp) data = deepcopy(MOCK_RULE_USER_INPUT_OUTPUT) - data['id'] = MOCK_ID + data["id"] = MOCK_ID with app.app_context(): res = getAllPermissions.__wrapped__(limit=1) assert res == [data] @@ -235,12 +248,13 @@ def test_postPermission(self): base_config = Config(db=self.db) base_config.security.access_control = self.access_control app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection app.config["casbin_adapter"] = Adapter( uri=f"mongodb://localhost:{self.db_port}/", dbname=self.access_db, - collection=self.access_col + collection=self.access_col, ) with app.test_request_context(json=deepcopy(MOCK_RULE)): @@ -253,12 +267,13 @@ def test_postPermission_InternalServerError(self): base_config = Config(db=self.db) base_config.security.access_control = self.access_control app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection app.config["casbin_adapter"] = Adapter( uri=f"mongodb://localhost:{self.db_port}/", dbname=self.access_db, - collection=self.access_col + collection=self.access_col, ) with app.test_request_context(json=deepcopy(MOCK_RULE_INVALID)): @@ -271,12 +286,13 @@ def test_postPermission_BadRequest(self): base_config = Config(db=self.db) base_config.security.access_control = self.access_control app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection app.config["casbin_adapter"] = Adapter( uri=f"mongodb://localhost:{self.db_port}/", dbname=self.access_db, - collection=self.access_col + collection=self.access_col, ) with app.test_request_context(json=""): @@ -304,8 +320,9 @@ def test_putPermission(self): **ACCESS_CONTROL_CONFIG ) app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection with app.test_request_context(json=deepcopy(MOCK_RULE)): res = putPermission.__wrapped__(id=MOCK_ID) @@ -320,8 +337,9 @@ def test_putPermission_InternalServerError(self): **ACCESS_CONTROL_CONFIG ) app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection with app.test_request_context(json=deepcopy(MOCK_RULE_INVALID)): with pytest.raises(InternalServerError): @@ -335,8 +353,9 @@ def test_putPermission_BadRequest(self): **ACCESS_CONTROL_CONFIG ) app.config.foca = base_config - app.config.foca.db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config.foca.db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection with app.test_request_context(json=""): with pytest.raises(BadRequest): diff --git a/tests/security/access_control/test_register_access_control.py b/tests/security/access_control/test_register_access_control.py index 376819af..34650420 100644 --- a/tests/security/access_control/test_register_access_control.py +++ b/tests/security/access_control/test_register_access_control.py @@ -6,9 +6,7 @@ from unittest import TestCase import pytest -from foca.security.access_control.register_access_control import ( - check_permissions -) +from foca.security.access_control.register_access_control import check_permissions from foca.security.access_control.foca_casbin_adapter.adapter import Adapter from foca.errors.exceptions import Forbidden from foca.models.config import AccessControlConfig, Config, MongoConfig @@ -16,7 +14,7 @@ ACCESS_CONTROL_CONFIG, MOCK_REQUEST, MONGO_CONFIG, - MOCK_PERMISSION + MOCK_PERMISSION, ) @@ -45,33 +43,26 @@ def test_check_permission_allowed(self): """Test to check only valid user requests are permitted via enforcer.""" app = Flask(__name__) - app.config["FOCA"] = Config( - db=self.db, - access_control=self.access_control - ) - app.config["FOCA"].db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config["FOCA"] = Config(db=self.db, access_control=self.access_control) + app.config["FOCA"].db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection app.config["casbin_adapter"] = Adapter( uri=f"mongodb://localhost:{self.db_port}/", dbname=self.access_db, - collection=self.access_col - ) - app.config["casbin_adapter"].save_policy_line( - ptype="p", - rule=MOCK_PERMISSION + collection=self.access_col, ) + app.config["casbin_adapter"].save_policy_line(ptype="p", rule=MOCK_PERMISSION) app.config["CASBIN_MODEL"] = self.access_control.model app.config["CASBIN_OWNER_HEADERS"] = self.access_control.owner_headers - app.config["CASBIN_USER_NAME_HEADERS"] = self.access_control.\ - user_headers + app.config["CASBIN_USER_NAME_HEADERS"] = self.access_control.user_headers @check_permissions def mock_func(): return "pass" with app.test_request_context( - environ_base=MOCK_REQUEST, - headers={"X-User": "alice"} + environ_base=MOCK_REQUEST, headers={"X-User": "alice"} ): response = mock_func() assert response == "pass" @@ -84,29 +75,25 @@ def test_check_permission_allowed_casbin_permission_not_found(self): """Test to check only user forbidden in case permission is not present.""" app = Flask(__name__) - app.config["FOCA"] = Config( - db=self.db, - access_control=self.access_control - ) - app.config["FOCA"].db.dbs[self.access_db].collections[self.access_col]\ - .client = mongomock.MongoClient().db.collection + app.config["FOCA"] = Config(db=self.db, access_control=self.access_control) + app.config["FOCA"].db.dbs[self.access_db].collections[ + self.access_col + ].client = mongomock.MongoClient().db.collection app.config["casbin_adapter"] = Adapter( uri=f"mongodb://localhost:{self.db_port}/", dbname=self.access_db, - collection=self.access_col + collection=self.access_col, ) app.config["CASBIN_MODEL"] = self.access_control.model app.config["CASBIN_OWNER_HEADERS"] = self.access_control.owner_headers - app.config["CASBIN_USER_NAME_HEADERS"] = self.access_control.\ - user_headers + app.config["CASBIN_USER_NAME_HEADERS"] = self.access_control.user_headers @check_permissions def mock_func(): return "pass" with app.test_request_context( - environ_base=MOCK_REQUEST, - headers={"X-Admin": "alice"} + environ_base=MOCK_REQUEST, headers={"X-Admin": "alice"} ): with pytest.raises(Forbidden): mock_func() diff --git a/tests/security/test_auth.py b/tests/security/test_auth.py index d995eac8..675b84bb 100644 --- a/tests/security/test_auth.py +++ b/tests/security/test_auth.py @@ -5,11 +5,11 @@ from connexion.exceptions import Unauthorized from flask import Flask -from jwt.exceptions import (InvalidKeyError, InvalidTokenError) +from jwt.exceptions import InvalidKeyError, InvalidTokenError import pytest from requests.exceptions import ConnectionError -from foca.models.config import (Config, ValidationChecksEnum) +from foca.models.config import Config, ValidationChecksEnum from foca.security.auth import ( _get_public_keys, _validate_jwt_userinfo, @@ -18,19 +18,19 @@ ) DICT_EMPTY = {} -MOCK_BYTES = b'my-mock-bytes' +MOCK_BYTES = b"my-mock-bytes" MOCK_CLAIMS_ISSUER = {"iss": "some-mock-issuer"} -MOCK_USER_ID = '1234567890' +MOCK_USER_ID = "1234567890" MOCK_CLAIMS_NO_SUB = { - 'azp': 'my-azp', - 'scope': 'email openid profile', - 'iss': 'https://my.issuer.org/oidc/', - 'exp': 1000010000, - 'iat': 1000000000, - 'jti': 'my-jti', + "azp": "my-azp", + "scope": "email openid profile", + "iss": "https://my.issuer.org/oidc/", + "exp": 1000010000, + "iat": 1000000000, + "jti": "my-jti", } MOCK_CLAIMS = deepcopy(MOCK_CLAIMS_NO_SUB) -MOCK_CLAIMS['sub'] = 'user@issuer.org' +MOCK_CLAIMS["sub"] = "user@issuer.org" MOCK_KEYS = { "abc": ( "uVHPfUHVEzpgOnDNi3e2pVsbK1hsINsTy_1mMT7sxDyP-1eQSjzYsGSUJ3GH" @@ -128,51 +128,50 @@ class TestValidateToken: def test_success_all_validation_checks(self, monkeypatch): """Test for validating token successfully via all methods.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) - request = MagicMock(name='requests') + setattr(app.config, "foca", Config()) + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = { - 'userinfo_endpoint': MOCK_URL, - 'jwks_uri': MOCK_URL, + "userinfo_endpoint": MOCK_URL, + "jwks_uri": MOCK_URL, } - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_userinfo', + "foca.security.auth._validate_jwt_userinfo", lambda **kwargs: None, ) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_public_key', + "foca.security.auth._validate_jwt_public_key", lambda **kwargs: None, ) with app.test_request_context(headers=MOCK_HEADERS): res = validate_token(token=MOCK_TOKEN_HEADER_KID) - assert res['user_id'] == MOCK_USER_ID + assert res["user_id"] == MOCK_USER_ID def test_success_any_validation_check(self, monkeypatch): """Test for validating token successfully via any method.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) - app.config.foca.security.auth.\ - validation_checks = ValidationChecksEnum.any - request = MagicMock(name='requests') + setattr(app.config, "foca", Config()) + app.config.foca.security.auth.validation_checks = ValidationChecksEnum.any + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = { - 'userinfo_endpoint': MOCK_URL, - 'jwks_uri': MOCK_URL, + "userinfo_endpoint": MOCK_URL, + "jwks_uri": MOCK_URL, } - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_userinfo', + "foca.security.auth._validate_jwt_userinfo", lambda **kwargs: None, ) with app.test_request_context(headers=MOCK_HEADERS): res = validate_token(token=MOCK_TOKEN_HEADER_KID) - assert res['user_id'] == MOCK_USER_ID + assert res["user_id"] == MOCK_USER_ID def test_no_validation_methods(self): """Test for failed validation due to missing validation methods.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) app.config.foca.security.auth.validation_methods = [] with app.test_request_context(headers=MOCK_HEADERS): with pytest.raises(Unauthorized): @@ -181,7 +180,7 @@ def test_no_validation_methods(self): def test_invalid_token(self): """Test for failed validation due to invalid token.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) with app.test_request_context(headers=MOCK_HEADERS): with pytest.raises(Unauthorized): validate_token(token=MOCK_TOKEN_INVALID) @@ -189,9 +188,9 @@ def test_invalid_token(self): def test_no_claims(self, monkeypatch): """Test for token with no issuer claim.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) monkeypatch.setattr( - 'jwt.decode', + "jwt.decode", lambda *args, **kwargs: {}, ) with app.test_request_context(headers=MOCK_HEADERS): @@ -201,11 +200,8 @@ def test_no_claims(self, monkeypatch): def test_oidc_config_unavailable(self, monkeypatch): """Test for mocking an unavailable OIDC configuration server.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) - monkeypatch.setattr( - 'requests.get', - lambda **kwargs: _raise(ConnectionError) - ) + setattr(app.config, "foca", Config()) + monkeypatch.setattr("requests.get", lambda **kwargs: _raise(ConnectionError)) with app.test_request_context(headers=MOCK_HEADERS): with pytest.raises(Unauthorized): validate_token(token=MOCK_TOKEN_HEADER_KID) @@ -213,24 +209,24 @@ def test_oidc_config_unavailable(self, monkeypatch): def test_success_no_subject_claim(self, monkeypatch): """Test for validating token without subject claim.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + setattr(app.config, "foca", Config()) monkeypatch.setattr( - 'jwt.decode', + "jwt.decode", lambda *args, **kwargs: MOCK_CLAIMS_NO_SUB, ) - request = MagicMock(name='requests') + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = { - 'userinfo_endpoint': MOCK_URL, - 'jwks_uri': MOCK_URL, + "userinfo_endpoint": MOCK_URL, + "jwks_uri": MOCK_URL, } - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_userinfo', + "foca.security.auth._validate_jwt_userinfo", lambda **kwargs: None, ) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_public_key', + "foca.security.auth._validate_jwt_public_key", lambda **kwargs: None, ) with app.test_request_context(headers=MOCK_HEADERS): @@ -241,20 +237,20 @@ def test_fail_all_validation_checks_all_required(self, monkeypatch): """Test for all token validation methods failing when all methods are required to pass.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) - request = MagicMock(name='requests') + setattr(app.config, "foca", Config()) + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = { - 'userinfo_endpoint': MOCK_URL, - 'jwks_uri': MOCK_URL, + "userinfo_endpoint": MOCK_URL, + "jwks_uri": MOCK_URL, } - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_userinfo', + "foca.security.auth._validate_jwt_userinfo", lambda **kwargs: _raise(ConnectionError), ) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_public_key', + "foca.security.auth._validate_jwt_public_key", lambda **kwargs: _raise(Unauthorized), ) with app.test_request_context(headers=MOCK_HEADERS): @@ -265,22 +261,21 @@ def test_fail_all_validation_checks_any_required(self, monkeypatch): """Test for all token validation methods failing when any method is required to pass.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) - app.config.foca.security.auth.\ - validation_checks = ValidationChecksEnum.any - request = MagicMock(name='requests') + setattr(app.config, "foca", Config()) + app.config.foca.security.auth.validation_checks = ValidationChecksEnum.any + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = { - 'userinfo_endpoint': MOCK_URL, - 'jwks_uri': MOCK_URL, + "userinfo_endpoint": MOCK_URL, + "jwks_uri": MOCK_URL, } - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_userinfo', + "foca.security.auth._validate_jwt_userinfo", lambda **kwargs: _raise(ConnectionError), ) monkeypatch.setattr( - 'foca.security.auth._validate_jwt_public_key', + "foca.security.auth._validate_jwt_public_key", lambda **kwargs: _raise(Unauthorized), ) with app.test_request_context(headers=MOCK_HEADERS): @@ -293,10 +288,10 @@ class TestValidateJwtUserinfo: def test_success(self, monkeypatch): """Test for validating a token successfully.""" - request = MagicMock(name='requests') + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = {} - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) res = _validate_jwt_userinfo( token=MOCK_TOKEN, url=MOCK_URL, @@ -305,10 +300,7 @@ def test_success(self, monkeypatch): def test_ConnectionError(self, monkeypatch): """Test for being unable to connect to user info endpoint.""" - monkeypatch.setattr( - 'requests.get', - lambda **kwargs: _raise(ConnectionError) - ) + monkeypatch.setattr("requests.get", lambda **kwargs: _raise(ConnectionError)) with pytest.raises(ConnectionError): _validate_jwt_userinfo( token=MOCK_TOKEN, @@ -322,11 +314,11 @@ class TestValidateJwtPublicKey: def test_success(self, monkeypatch): """Test for validating a token successfully.""" monkeypatch.setattr( - 'foca.security.auth._get_public_keys', + "foca.security.auth._get_public_keys", lambda **kwargs: MOCK_KEYS, ) monkeypatch.setattr( - 'jwt.decode', + "jwt.decode", lambda *args, **kwargs: MOCK_CLAIMS, ) res = _validate_jwt_public_key( @@ -339,11 +331,11 @@ def test_success(self, monkeypatch): def test_InvalidKeyError(self, monkeypatch): """Test for invalid key.""" monkeypatch.setattr( - 'foca.security.auth._get_public_keys', + "foca.security.auth._get_public_keys", lambda **kwargs: MOCK_KEYS, ) monkeypatch.setattr( - 'jwt.decode', + "jwt.decode", lambda **kwargs: _raise(InvalidKeyError), ) with pytest.raises(Unauthorized): @@ -355,11 +347,11 @@ def test_InvalidKeyError(self, monkeypatch): def test_InvalidTokenError(self, monkeypatch): """Test for invalid token.""" monkeypatch.setattr( - 'foca.security.auth._get_public_keys', + "foca.security.auth._get_public_keys", lambda **kwargs: MOCK_KEYS, ) monkeypatch.setattr( - 'jwt.decode', + "jwt.decode", lambda **kwargs: _raise(InvalidTokenError), ) with pytest.raises(Unauthorized): @@ -371,7 +363,7 @@ def test_InvalidTokenError(self, monkeypatch): def test_no_header_claims(self, monkeypatch): """Test for token without header claims.""" monkeypatch.setattr( - 'foca.security.auth._get_public_keys', + "foca.security.auth._get_public_keys", lambda **kwargs: MOCK_KEYS, ) with pytest.raises(Unauthorized): @@ -383,7 +375,7 @@ def test_no_header_claims(self, monkeypatch): def test_kid_mismatch(self, monkeypatch): """Test for token and JWK set with mismatching JWK identifiers.""" monkeypatch.setattr( - 'foca.security.auth._get_public_keys', + "foca.security.auth._get_public_keys", lambda **kwargs: MOCK_KEYS, ) with pytest.raises(KeyError): @@ -399,28 +391,25 @@ class TestGetPublicKeys: def test_success(self, monkeypatch): """Test for successfully fetching keys.""" mock_jwk_set = {"keys": [MOCK_JWK, {}]} - request = MagicMock(name='requests') + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = mock_jwk_set - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) res = _get_public_keys(url=MOCK_URL, pem=True) - assert MOCK_JWK['kid'] in res + assert MOCK_JWK["kid"] in res def test_ConnectionError(self, monkeypatch): """Test for being unable to connect to keys endpoint.""" - monkeypatch.setattr( - 'requests.get', - lambda **kwargs: _raise(ConnectionError) - ) + monkeypatch.setattr("requests.get", lambda **kwargs: _raise(ConnectionError)) with pytest.raises(ConnectionError): _get_public_keys(url=MOCK_URL) def test_non_public_key(self, monkeypatch): """Test for non-public keys.""" mock_jwk_set = {"keys": [MOCK_JWK_PRIVATE]} - request = MagicMock(name='requests') + request = MagicMock(name="requests") request.status_code = 200 request.return_value.json.return_value = mock_jwk_set - monkeypatch.setattr('requests.get', request) + monkeypatch.setattr("requests.get", request) res = _get_public_keys(url=MOCK_URL) assert res == {} diff --git a/tests/security/test_cors.py b/tests/security/test_cors.py index 376b120f..e9567d72 100644 --- a/tests/security/test_cors.py +++ b/tests/security/test_cors.py @@ -10,6 +10,6 @@ def test_enable_cors(): """Test that CORS is called with app as a parameter.""" app = Flask(__name__) - with patch('foca.security.cors.CORS') as mock_cors: + with patch("foca.security.cors.CORS") as mock_cors: enable_cors(app) mock_cors.assert_called_once_with(app) diff --git a/tests/test_files/model_valid.py b/tests/test_files/model_valid.py index afad18c6..d1b74233 100644 --- a/tests/test_files/model_valid.py +++ b/tests/test_files/model_valid.py @@ -7,4 +7,5 @@ class CustomConfig(BaseModel): Args: param: Test parameter. """ - param: str = 'STRING' + + param: str = "STRING" diff --git a/tests/test_files/models_petstore.py b/tests/test_files/models_petstore.py index eee1c343..c442f27f 100644 --- a/tests/test_files/models_petstore.py +++ b/tests/test_files/models_petstore.py @@ -21,6 +21,7 @@ class Pet(BaseModel): name: The pet's name. tag: Optional tag for the pet. """ + id: int name: str tag: Optional[str] @@ -35,6 +36,7 @@ class Pets(BaseModel): Attributes: pets: List of pets. """ + pets: List[Pet] = [] @@ -49,5 +51,6 @@ class Error(BaseModel): code: Status code. message: Error message. """ + code: int message: str diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index aa80709a..b56166e4 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -10,13 +10,13 @@ def test_find_one_latest(): field. """ collection = mongomock.MongoClient().db.collection - obj1 = {'_id': 1, 'name': 'first'} - obj2 = {'_id': 2, 'name': 'seond'} - obj3 = {'_id': 3, 'name': 'third'} + obj1 = {"_id": 1, "name": "first"} + obj2 = {"_id": 2, "name": "seond"} + obj3 = {"_id": 3, "name": "third"} collection.insert_many([obj1, obj2, obj3]) res = find_one_latest(collection) - assert res == {'name': 'third'} + assert res == {"name": "third"} def test_find_one_latest_returns_None(): @@ -28,9 +28,9 @@ def test_find_one_latest_returns_None(): def test_find_id_latest(): """Test that find_id_latest return recently added id.""" collection = mongomock.MongoClient().db.collection - obj1 = {'_id': 1, 'name': 'first'} - obj2 = {'_id': 2, 'name': 'seond'} - obj3 = {'_id': 3, 'name': 'third'} + obj1 = {"_id": 1, "name": "first"} + obj2 = {"_id": 2, "name": "seond"} + obj3 = {"_id": 3, "name": "third"} collection.insert_many([obj1, obj2, obj3]) res = find_id_latest(collection) diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index a5d643a9..1c6c05c2 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -8,10 +8,10 @@ app = Flask(__name__) REQ = { - 'REQUEST_METHOD': 'GET', - 'PATH_INFO': '/', - 'SERVER_PROTOCOL': 'HTTP/1.1', - 'REMOTE_ADDR': '192.168.1.1', + "REQUEST_METHOD": "GET", + "PATH_INFO": "/", + "SERVER_PROTOCOL": "HTTP/1.1", + "REMOTE_ADDR": "192.168.1.1", } # Get logger instance @@ -24,12 +24,11 @@ def test_logging_decorator(caplog): @log_traffic def mock_func(): - return {'foo': 'bar'} + return {"foo": "bar"} with app.test_request_context(environ_base=REQ): mock_func() - assert 'Incoming request' in caplog.text \ - and 'Response to request' in caplog.text + assert "Incoming request" in caplog.text and "Response to request" in caplog.text def test_logging_decorator_log_level(caplog): @@ -37,11 +36,11 @@ def test_logging_decorator_log_level(caplog): @log_traffic(log_level=30) def mock_func(): - return {'foo': 'bar'} + return {"foo": "bar"} with app.test_request_context(environ_base=REQ): mock_func() - assert 'WARNING' in caplog.text + assert "WARNING" in caplog.text def test_logging_decorator_req_only(caplog): @@ -50,12 +49,13 @@ def test_logging_decorator_req_only(caplog): @log_traffic(log_response=False) def mock_func(): - return {'foo': 'bar'} + return {"foo": "bar"} with app.test_request_context(environ_base=REQ): mock_func() - assert 'Incoming request' in caplog.text \ - and 'Response to request' not in caplog.text + assert ( + "Incoming request" in caplog.text and "Response to request" not in caplog.text + ) def test_logging_decorator_res_only(caplog): @@ -64,9 +64,10 @@ def test_logging_decorator_res_only(caplog): @log_traffic(log_request=False) def mock_func(): - return {'foo': 'bar'} + return {"foo": "bar"} with app.test_request_context(environ_base=REQ): mock_func() - assert 'Incoming request' not in caplog.text \ - and 'Response to request' in caplog.text + assert ( + "Incoming request" not in caplog.text and "Response to request" in caplog.text + ) diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index a5f9a40e..4c384ee3 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -8,7 +8,6 @@ class TestGenerateId: - def test_default(self): """Use only default arguments.""" res = generate_id() @@ -21,8 +20,7 @@ def test_charset_literal_string(self): assert set(res) <= set(string.digits) def test_charset_literal_string_duplicates(self): - """Argument to `charset` is non-default literal string with duplicates. - """ + """Argument to `charset` is non-default literal string with duplicates.""" charset = string.digits + string.digits res = generate_id(charset=charset) assert set(res) <= set(string.digits)