diff --git a/odev/_version.py b/odev/_version.py index 1335cae8..541da10a 100644 --- a/odev/_version.py +++ b/odev/_version.py @@ -22,4 +22,4 @@ # or merged change. # ------------------------------------------------------------------------------ -__version__ = "4.19.1" +__version__ = "4.20.0" diff --git a/odev/commands/database/database.py b/odev/commands/database/database.py new file mode 100644 index 00000000..07c3e935 --- /dev/null +++ b/odev/commands/database/database.py @@ -0,0 +1,111 @@ +from odev.common import args, progress +from odev.common.commands import GitCommand, LocalDatabaseCommand +from odev.common.connectors import GitConnector +from odev.common.logging import logging +from odev.common.python import PythonEnv + + +logger = logging.getLogger(__name__) + + +class DatabaseSetCommand(LocalDatabaseCommand, GitCommand): + """Edit local databases' parameters.""" + + _name = "database" + _aliases = ["db"] + + set_repository = args.String( + aliases=["--set-repo"], + description="Change the repository linked to the database, in the format /.", + metavar="REPOSITORY", + ) + remove_repository = args.Flag( + aliases=["--remove-repo"], + description="Remove the repository linked to the database.", + ) + + set_venv = args.String( + aliases=["--set-venv"], + description="Change the virtualenv linked to the database.", + metavar="VENV", + ) + remove_venv = args.Flag( + aliases=["--remove-venv"], + description="Remove the virtualenv linked to the database.", + ) + + set_worktree = args.String( + aliases=["--set-worktree"], + description="Change the worktree linked to the database.", + metavar="WORKTREE", + ) + remove_worktree = args.Flag( + aliases=["--remove-worktree"], + description="Remove the worktree linked to the database.", + ) + + whitelist = args.FlagOptional( + aliases=["--whitelist"], + description="Whitelist or unwhitelist the database.", + ) + + @classmethod + def prepare_command(cls, *args, **kwargs) -> None: + super().prepare_command(*args, **kwargs) + cls.remove_argument("version") + + def run(self): + self.args.version = None + + with progress.spinner("Setting database parameters"): + self._set_values() + self._remove_values() + + def _set_values(self): + if self.args.set_repository: + repo = GitConnector(self.args.set_repository) + + if not repo.exists and self.console.confirm("Repository not found locally, clone now?"): + self.odev.run_command("clone", repo.name) + + self.store.databases.set_value(self._database, "repository", f"{self.args.set_repository!r}") + self.store.databases.set_value(self._database, "branch", "NULL") + logger.info(f"Repository set to {self.args.set_repository!r}") + + if self.args.set_venv: + venv = PythonEnv(self.args.set_venv) + + if venv.exists: + self.store.databases.set_value(self._database, "virtualenv", f"{self.args.set_venv!r}") + logger.info(f"Virtualenv set to {self.args.set_venv!r}") + else: + logger.error(f"Virtualenv {self.args.set_venv!r} not found, please create it and retry") + + if self.args.set_worktree: + if self.args.set_worktree in self.grouped_worktrees: + self.store.databases.set_value(self._database, "worktree", f"{self.args.set_worktree!r}") + logger.info(f"Worktree set to {self.args.set_worktree!r}") + else: + logger.error(f"Worktree {self.args.set_worktree!r} not found, please create it and retry") + + if self.args.whitelist is True: + self.store.databases.set_value(self._database, "whitelisted", "TRUE") + logger.info("Database whitelisted") + + def _remove_values(self): + if self.args.remove_repository: + self.store.databases.set_value(self._database, "repository", "NULL") + self.store.databases.set_value(self._database, "branch", "NULL") + logger.info("Repository removed") + + if self.args.remove_venv: + self.store.databases.set_value(self._database, "virtualenv", "NULL") + logger.info("Virtualenv removed") + + if self.args.remove_worktree: + self.store.databases.set_value(self._database, "worktree", "NULL") + logger.info("Worktree removed") + + if self.args.whitelist is False: + self.store.databases.set_value(self._database, "whitelisted", "FALSE") + logger.info("Database unwhitelisted") diff --git a/odev/common/arguments.py b/odev/common/arguments.py index 53e3052f..28191ea4 100644 --- a/odev/common/arguments.py +++ b/odev/common/arguments.py @@ -2,11 +2,9 @@ import pathlib import re +from argparse import BooleanOptionalAction from collections.abc import MutableMapping -from typing import ( - Any, - Literal, -) +from typing import Any, Literal class Argument: @@ -201,7 +199,37 @@ def __init__( name=name, aliases=aliases, description=description, - action="store_false" if default is True else "store_true", + action=kwargs.pop("action", None) or ("store_false" if default is True else "store_true"), + **kwargs, + ) + + +class FlagOptional(Flag): + """Flag with a boolean value and automatic counter option (--flag and --no-flag).""" + + def __init__( + self, + name: str | None = None, + aliases: list[str] | None = None, + description: str | None = None, + **kwargs: Any, + ) -> None: + """Add a flag that has a boolean value which depends on whether it was passed in the command line. + + The default value is inverted if the flag is set. + :param name: The name of the argument, will be used in the help command and in the command's class `args` attribute. + :param aliases: The aliases for the argument. + :param description: A description for the argument, will be displayed in the `help` command. + :param default: The default value for the argument; a default value of `False` will result in the argument + being set to `True` if present in the CLI arguments. + :param kwargs: Additional keyword arguments to pass to the ArgumentParser. + See: https://docs.python.org/3/library/argparse.html#quick-links-for-add-argument + """ + super().__init__( + name=name, + aliases=aliases, + description=description, + action=BooleanOptionalAction, **kwargs, ) diff --git a/odev/common/store/tables/databases.py b/odev/common/store/tables/databases.py index b9cf26e7..9c9e4951 100644 --- a/odev/common/store/tables/databases.py +++ b/odev/common/store/tables/databases.py @@ -106,6 +106,17 @@ def set(self, database: Database, arguments: str | None = None): """ ) + def set_value(self, database: Database, key: str, value: str): + """Set a value for a database.""" + self.database.query( + f""" + UPDATE {self.name} + SET {key} = {value} + WHERE name = {database.name!r} + AND platform = {database.platform.name!r} + """ + ) + def delete(self, database: Database): """Delete the saved values of a database.""" self.database.query(