diff --git a/src/huggingface_hub/commands/branch.py b/src/huggingface_hub/commands/branch.py new file mode 100644 index 0000000000..8914ff8c4e --- /dev/null +++ b/src/huggingface_hub/commands/branch.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to perform branch management with the CLI. + +Usage Examples: + # Create a new branch in a repository on huggingface.co + huggingface-cli branch create my-cool-model my-great-branch + + # List branches in a repositories on huggingface.co + huggingface-cli branch list my-cool-model + + # Delete a branch in a repository on huggingface.co + huggingface-cli branch delete my-cool-model my-great-branch +""" +import subprocess +from argparse import Namespace, _SubParsersAction + +from requests.exceptions import HTTPError + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import ( + REPO_TYPES, + REPO_TYPES_URL_PREFIXES, +) +from huggingface_hub.hf_api import HfApi + +from ..utils import HfFolder +from ._cli_utils import ANSI + + +class BranchCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + branch_parser = parser.add_parser( + "branch", + help="{create, list, delete} commands to interact with your huggingface.co repo branches.", + ) + branch_subparsers = branch_parser.add_subparsers(help="huggingface.co repo branch related commands") + branch_create_parser = branch_subparsers.add_parser( + "create", help="Create a new branch for your repo on huggingface.co" + ) + branch_create_parser.add_argument( + "repo_id", + type=str, + help="The repository in which the branch will be created.", + ) + branch_create_parser.add_argument( + "branch", + type=str, + help="The name of the branch to create.", + ) + branch_create_parser.add_argument( + "--revision", + type=str, + help="The git revision to create the branch from.", + ) + branch_create_parser.add_argument( + "--type", + choices=["model", "dataset", "space"], + default="model", + help=( + 'Optional: type: set to "dataset" or "space" if creating a branch in a dataset or space, default is' + " model." + ), + ) + branch_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") + branch_create_parser.add_argument( + "-y", + "--yes", + action="store_true", + help="Optional: answer Yes to the prompt", + ) + branch_create_parser.set_defaults(func=lambda args: BranchCreateCommand(args)) + branch_list_parser = branch_subparsers.add_parser("list", help="List branches of the repo on huggingface.co") + branch_list_parser.add_argument( + "repo_id", + type=str, + help="The repository for which to list branches.", + ) + branch_list_parser.add_argument( + "--type", + choices=["model", "dataset", "space"], + default="model", + help="Type of repo to list branches for, default is model.", + ) + branch_list_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") + branch_list_parser.set_defaults(func=lambda args: BranchListCommand(args)) + branch_delete_parser = branch_subparsers.add_parser( + "delete", help="Delete a branch for your repo on huggingface.co" + ) + branch_delete_parser.add_argument( + "repo_id", + type=str, + help="Name of the repo in which to delete the branch.", + ) + branch_delete_parser.add_argument( + "branch", + type=str, + help="The name of the branch to delete.", + ) + branch_delete_parser.add_argument( + "--type", + choices=["model", "dataset", "space"], + default="model", + help="Type of the repo to delete the branch in, default is model.", + ) + branch_delete_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") + branch_delete_parser.add_argument( + "-y", + "--yes", + action="store_true", + help="Optional: answer Yes to the prompt", + ) + branch_delete_parser.set_defaults(func=lambda args: BranchDeleteCommand(args)) + + +class BaseBranchCommand: + def __init__(self, args: Namespace): + self.args = args + self._api = HfApi() + self.token = HfFolder.get_token() + if self.token is None: + print("Not logged in") + exit(1) + try: + stdout = subprocess.check_output(["git", "--version"]).decode("utf-8") + print(ANSI.gray(stdout.strip())) + except FileNotFoundError: + print("Looks like you do not have git installed, please install.") + + +class BranchCreateCommand(BaseBranchCommand): + def run(self): + try: + stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8") + print(ANSI.gray(stdout.strip())) + except FileNotFoundError: + print( + ANSI.red( + "Looks like you do not have git-lfs installed, please install." + " You can install from https://git-lfs.github.com/." + " Then run `git lfs install` (you only have to do this once)." + ) + ) + print("") + + user = self._api.whoami(self.token)["name"] + namespace = self.args.organization if self.args.organization is not None else user + + repo_id = f"{namespace}/{self.args.repo_id}" + + if self.args.type not in REPO_TYPES: + print("Invalid repo --type") + exit(1) + + if self.args.type in REPO_TYPES_URL_PREFIXES: + prefixed_repo_id = REPO_TYPES_URL_PREFIXES[self.args.type] + repo_id + else: + prefixed_repo_id = repo_id + + print(f"You are about to create branch {ANSI.bold(self.args.branch)} on {ANSI.bold(prefixed_repo_id)}") + + if not self.args.yes: + choice = input("Proceed? [Y/n] ").lower() + if not (choice == "" or choice == "y" or choice == "yes"): + print("Abort") + exit() + try: + self._api.create_branch( + repo_id=repo_id, + branch=self.args.branch, + revision=self.args.revision, + token=self.token, + repo_type=self.args.type, + ) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + print("\nYour branch now lives at:") + print(f" {prefixed_repo_id}/branch/{ANSI.bold(self.args.branch)}") + + +class BranchListCommand(BaseBranchCommand): + def run(self): + self.type = self.args.type + user = self._api.whoami(self.token)["name"] + namespace = self.args.organization if self.args.organization is not None else user + + repo_id = f"{namespace}/{self.args.repo_id}" + + try: + refs = self._api.list_repo_refs( + repo_id=repo_id, + repo_type=self.type, + ) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + print("\nYour branches:") + for branch in refs.branches: + print(f" {ANSI.bold(branch.name)}") + print("") + + +class BranchDeleteCommand(BaseBranchCommand): + def run(self): + user = self._api.whoami(self.token)["name"] + namespace = self.args.organization if self.args.organization is not None else user + + repo_id = f"{namespace}/{self.args.repo_id}" + + if self.args.type not in REPO_TYPES: + print("Invalid repo --type") + exit(1) + + if self.args.type in REPO_TYPES_URL_PREFIXES: + prefixed_repo_id = REPO_TYPES_URL_PREFIXES[self.args.type] + repo_id + else: + prefixed_repo_id = repo_id + + print(f"You are about to delete branch {ANSI.bold(self.args.branch)} on {ANSI.bold(prefixed_repo_id)}") + + if not self.args.yes: + choice = input("Proceed? [Y/n] ").lower() + if not (choice == "" or choice == "y" or choice == "yes"): + print("Abort") + exit() + try: + self._api.delete_branch( + repo_id=repo_id, branch=self.args.branch, token=self.token, repo_type=self.args.type + ) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + print("\nYour branch has been deleted.") + print("") diff --git a/src/huggingface_hub/commands/huggingface_cli.py b/src/huggingface_hub/commands/huggingface_cli.py index 39b6dfe49a..cd66714268 100644 --- a/src/huggingface_hub/commands/huggingface_cli.py +++ b/src/huggingface_hub/commands/huggingface_cli.py @@ -15,6 +15,7 @@ from argparse import ArgumentParser +from huggingface_hub.commands.branch import BranchCommands from huggingface_hub.commands.delete_cache import DeleteCacheCommand from huggingface_hub.commands.download import DownloadCommand from huggingface_hub.commands.env import EnvironmentCommand @@ -29,6 +30,7 @@ def main(): commands_parser = parser.add_subparsers(help="huggingface-cli command helpers") # Register commands + BranchCommands.register_subcommand(commands_parser) EnvironmentCommand.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser) UploadCommand.register_subcommand(commands_parser) diff --git a/tests/test_cli.py b/tests/test_cli.py index 1549a3f2bb..e004eb127b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,6 +7,7 @@ from typing import Generator from unittest.mock import Mock, patch +from huggingface_hub.commands.branch import BranchCommands from huggingface_hub.commands.delete_cache import DeleteCacheCommand from huggingface_hub.commands.download import DownloadCommand from huggingface_hub.commands.scan_cache import ScanCacheCommand @@ -16,6 +17,42 @@ from .testing_utils import DUMMY_MODEL_ID +class TestBranchCommands(unittest.TestCase): + def setUp(self) -> None: + """ + Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. + """ + self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") + commands_parser = self.parser.add_subparsers() + BranchCommands.register_subcommand(commands_parser) + + def test_repo_create_basic(self) -> None: + """Test `huggingface-cli create branch in dummy-repo`.""" + args = self.parser.parse_args(["branch", "create", DUMMY_MODEL_ID, "my-great-branch"]) + self.assertEqual(args.repo_id, DUMMY_MODEL_ID) + self.assertEqual(args.branch, "my-great-branch") + self.assertEqual(args.revision, None) + self.assertEqual(args.type, "model") + self.assertIsNone(args.organization) + self.assertFalse(args.yes) + + def test_repo_list_basic(self) -> None: + """Test `huggingface-cli list branches in dummy-repo`.""" + args = self.parser.parse_args(["branch", "list", DUMMY_MODEL_ID]) + self.assertEqual(args.repo_id, DUMMY_MODEL_ID) + self.assertEqual(args.type, "model") + self.assertIsNone(args.organization) + + def test_repo_delete_basic(self) -> None: + """Test `huggingface-cli delete branch in dummy-repo`.""" + args = self.parser.parse_args(["branch", "delete", DUMMY_MODEL_ID, "my-great-branch"]) + self.assertEqual(args.repo_id, DUMMY_MODEL_ID) + self.assertEqual(args.branch, "my-great-branch") + self.assertEqual(args.type, "model") + self.assertIsNone(args.organization) + self.assertFalse(args.yes) + + class TestCacheCommand(unittest.TestCase): def setUp(self) -> None: """