-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_cache.py
More file actions
46 lines (38 loc) · 1.42 KB
/
model_cache.py
File metadata and controls
46 lines (38 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python
"""Command-line tool for maintaining the local Hugging Face model cache."""
import argparse
from huggingface_hub import scan_cache_dir, snapshot_download
# Hub repository id that the `remove` and `update` sub-commands act on when
# no -m/--model option is given.
MODEL_ID="Tongyi-MAI/Z-Image-Turbo"
def list_commit_hashes(model: str, repo_type="model", cache_info=None):
    """Return the commit hashes of every cached revision of a repository.

    Args:
        model: Hub repository id, e.g. ``"Tongyi-MAI/Z-Image-Turbo"``.
        repo_type: Only cached repositories of this type are matched, so a
            dataset or space that happens to share the id is not picked up.
            Pass ``None`` to match any repository type.
        cache_info: Result of an earlier ``scan_cache_dir()`` call to reuse;
            when ``None`` (the default) the cache is scanned here.

    Returns:
        A sorted list of unique commit hashes; empty if the repository is not
        in the cache.
    """
    if cache_info is None:
        cache_info = scan_cache_dir()
    hashes = set()
    for repo in cache_info.repos:
        if repo.repo_id != model:
            continue
        if repo_type is not None and repo.repo_type != repo_type:
            continue
        # Walk all revisions rather than ``repo.refs``: revisions that no
        # branch/tag points to any more (e.g. the old snapshot left behind
        # after ``main`` moves) only show up here, and a commit referenced by
        # several refs is collected once thanks to the set.
        hashes.update(revision.commit_hash for revision in repo.revisions)
    return sorted(hashes)
def main():
    """Maintain the local Hugging Face model cache from the command line.

    Sub-commands:
        list    (the default when no command is given) print a table of the
                cache contents.
        remove  delete the cached revisions of ``--model``.
        update  download the current snapshot of ``--model`` into the cache.
    """
    parser = argparse.ArgumentParser(
        description="Maintain the local Hugging Face model cache.")
    subparsers = parser.add_subparsers(dest="command")
    subparsers.add_parser("list", help="show the cache contents (default)")
    # `remove` and `update` both act on one repository, defaulting to MODEL_ID.
    for name, summary in (
        ("remove", "delete the cached revisions of a model"),
        ("update", "download the latest snapshot of a model"),
    ):
        sub = subparsers.add_parser(name, help=summary)
        sub.add_argument("-m", "--model", type=str, default=MODEL_ID,
                         help="Hub repository id (default: %(default)s)")
    args = parser.parse_args()

    if args.command in (None, "list"):
        print(scan_cache_dir().export_as_table(verbosity=1))
    elif args.command == "remove":
        hashes = list_commit_hashes(args.model)
        if not hashes:
            print(f"{args.model} is not in the cache; nothing to remove.")
            return
        # The scan used by list_commit_hashes is not returned, so scan again
        # to build the deletion plan for those revisions.
        delete_strategy = scan_cache_dir().delete_revisions(*hashes)
        freed = delete_strategy.expected_freed_size_str
        delete_strategy.execute()
        print(f"Removed {len(hashes)} revision(s) of {args.model}, "
              f"freed {freed}: {hashes}")
    elif args.command == "update":
        path = snapshot_download(repo_id=args.model)
        print(f"{args.model} snapshot is at {path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()