diff --git a/spkrepo/views/api.py b/spkrepo/views/api.py index 639c9e5..a3fe5b1 100644 --- a/spkrepo/views/api.py +++ b/spkrepo/views/api.py @@ -3,6 +3,7 @@ import os import re import shutil +import redis from functools import wraps from flask import Blueprint, _request_ctx_stack, current_app, request @@ -28,6 +29,8 @@ from ..utils import SPK api = Blueprint("api", __name__) +redis_client = redis.Redis(host='localhost', port=6379, db=0) +redis_lock = redis_client.lock('my_lock') # regexes firmware_re = re.compile(r"^(?P\d\.\d)-(?P\d{3,6})$") @@ -127,93 +130,135 @@ def post(self): if firmware is None: abort(422, message="Unknown firmware") + # path to save files + data_path = current_app.config["DATA_PATH"] + # Package - create_package = False - package = Package.find(spk.info["package"]) - if package is None: - if not current_user.has_role("package_admin"): - abort(403, message="Insufficient permissions to create new packages") - create_package = True - package = Package(name=spk.info["package"], author=current_user) - elif ( - not current_user.has_role("package_admin") - and current_user not in package.maintainers - ): - abort(403, message="Insufficient permissions on this package") + try: + # Acquire the Redis lock + redis_lock.acquire() + + create_package = False + package = Package.find(spk.info["package"]) + if package is None: + if not current_user.has_role("package_admin"): + abort(403, message="Insufficient permissions to create new packages") + create_package = True + package = Package(name=spk.info["package"], author=current_user) + elif ( + not current_user.has_role("package_admin") + and current_user not in package.maintainers + ): + abort(403, message="Insufficient permissions on this package") + + if create_package: + try: + os.mkdir(os.path.join(data_path, package.name)) + except Exception as e: # pragma: no cover + shutil.rmtree(os.path.join(data_path, package.name), ignore_errors=True) + abort(500, message="Failed to create directory", details=str(e)) + # Add package to database + db.session.add(package) + db.session.commit() + + finally: + # Release the Redis lock + redis_lock.release() # Version - create_version = False - match = version_re.match(spk.info["version"]) - if not match: - abort(422, message="Invalid version") - # TODO: check discrepencies with what's in the database - version = {v.version: v for v in package.versions}.get( - int(match.group("version")) - ) - if version is None: - create_version = True - version_startable = None - if spk.info.get("startable") is False or spk.info.get("ctl_stop") is False: - version_startable = False - elif spk.info.get("startable") is True or spk.info.get("ctl_stop") is True: - version_startable = True - version = Version( - package=package, - upstream_version=match.group("upstream_version"), - version=int(match.group("version")), - changelog=spk.info.get("changelog"), - report_url=spk.info.get("report_url"), - distributor=spk.info.get("distributor"), - distributor_url=spk.info.get("distributor_url"), - maintainer=spk.info.get("maintainer"), - maintainer_url=spk.info.get("maintainer_url"), - dependencies=spk.info.get("install_dep_packages"), - conf_dependencies=spk.conf_dependencies, - conflicts=spk.info.get("install_conflict_packages"), - conf_conflicts=spk.conf_conflicts, - conf_privilege=spk.conf_privilege, - conf_resource=spk.conf_resource, - install_wizard="install" in spk.wizards, - upgrade_wizard="upgrade" in spk.wizards, - startable=version_startable, - license=spk.license, + try: + # Acquire the Redis lock + redis_lock.acquire() + + create_version = False + match = version_re.match(spk.info["version"]) + if not match: + abort(422, message="Invalid version") + # TODO: check discrepencies with what's in the database + version = {v.version: v for v in package.versions}.get( + int(match.group("version")) ) + if version is None: + create_version = True + version_startable = None + if spk.info.get("startable") is False or spk.info.get("ctl_stop") is False: + version_startable = False + elif spk.info.get("startable") is True or spk.info.get("ctl_stop") is True: + version_startable = True + version = Version( + package=package, + upstream_version=match.group("upstream_version"), + version=int(match.group("version")), + changelog=spk.info.get("changelog"), + report_url=spk.info.get("report_url"), + distributor=spk.info.get("distributor"), + distributor_url=spk.info.get("distributor_url"), + maintainer=spk.info.get("maintainer"), + maintainer_url=spk.info.get("maintainer_url"), + dependencies=spk.info.get("install_dep_packages"), + conf_dependencies=spk.conf_dependencies, + conflicts=spk.info.get("install_conflict_packages"), + conf_conflicts=spk.conf_conflicts, + conf_privilege=spk.conf_privilege, + conf_resource=spk.conf_resource, + install_wizard="install" in spk.wizards, + upgrade_wizard="upgrade" in spk.wizards, + startable=version_startable, + license=spk.license, + ) - for key, value in spk.info.items(): - if key == "install_dep_services": - for service_name in value.split(): - version.service_dependencies.append(Service.find(service_name)) - elif key == "displayname": - version.displaynames["enu"] = DisplayName( - language=Language.find("enu"), displayname=value - ) - elif key.startswith("displayname_"): - language = Language.find(key.split("_", 1)[1]) - if not language: - abort(422, message="Unknown INFO displayname language") - version.displaynames[language.code] = DisplayName( - language=language, displayname=value - ) - elif key == "description": - version.descriptions["enu"] = Description( - description=value, language=Language.find("enu") - ) - elif key.startswith("description_"): - language = Language.find(key.split("_", 1)[1]) - if not language: - abort(422, message="Unknown INFO description language") - version.descriptions[language.code] = Description( - language=language, description=value + for key, value in spk.info.items(): + if key == "install_dep_services": + for service_name in value.split(): + version.service_dependencies.append(Service.find(service_name)) + elif key == "displayname": + version.displaynames["enu"] = DisplayName( + language=Language.find("enu"), displayname=value + ) + elif key.startswith("displayname_"): + language = Language.find(key.split("_", 1)[1]) + if not language: + abort(422, message="Unknown INFO displayname language") + version.displaynames[language.code] = DisplayName( + language=language, displayname=value + ) + elif key == "description": + version.descriptions["enu"] = Description( + description=value, language=Language.find("enu") + ) + elif key.startswith("description_"): + language = Language.find(key.split("_", 1)[1]) + if not language: + abort(422, message="Unknown INFO description language") + version.descriptions[language.code] = Description( + language=language, description=value + ) + + # Icon + for size, icon in spk.icons.items(): + version.icons[size] = Icon( + path=os.path.join( + package.name, str(version.version), "icon_%s.png" % size + ), + size=size, ) - # Icon - for size, icon in spk.icons.items(): - version.icons[size] = Icon( - path=os.path.join( - package.name, str(version.version), "icon_%s.png" % size - ), - size=size, - ) + if create_version: + try: + os.mkdir(os.path.join(data_path, package.name, str(version.version))) + except Exception as e: # pragma: no cover + shutil.rmtree( + os.path.join(data_path, package.name, str(version.version)), + ignore_errors=True, + ) + abort(500, message="Failed to create directory", details=str(e)) + # Add version to database + db.session.add(version) + db.session.commit() + + finally: + # Release the Redis lock + redis_lock.release() # Build if version.id: @@ -254,27 +299,15 @@ def post(self): # save files try: - data_path = current_app.config["DATA_PATH"] - if create_package: - os.mkdir(os.path.join(data_path, package.name)) if create_version: - os.mkdir(os.path.join(data_path, package.name, str(version.version))) for size, icon in build.version.icons.items(): icon.save(spk.icons[size]) build.save(spk.stream) except Exception as e: # pragma: no cover - if create_package: - shutil.rmtree(os.path.join(data_path, package.name), ignore_errors=True) - elif create_version: - shutil.rmtree( - os.path.join(data_path, package.name, str(version.version)), - ignore_errors=True, - ) - else: - try: - os.remove(os.path.join(data_path, build.path)) - except OSError: - pass + try: + os.remove(os.path.join(data_path, build.path)) + except OSError: + pass abort(500, message="Failed to save files", details=str(e)) # insert the package into database