Skip to content

Add an `inner_tqdm_class` parameter to `snapshot_download()` so the progress of individual file downloads can be tracked #2718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union, Type

import requests
from tqdm.auto import tqdm as base_tqdm
Expand Down Expand Up @@ -36,7 +36,9 @@ def snapshot_download(
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[base_tqdm] = None,
tqdm_class: Optional[Type[base_tqdm]] = None,
skip_outer_tqdm: Optional[bool] = False,
inner_tqdm_class: Optional[Type[base_tqdm]] = None,
headers: Optional[Dict[str, str]] = None,
endpoint: Optional[str] = None,
# Deprecated args
Expand Down Expand Up @@ -285,6 +287,7 @@ def _inner_hf_hub_download(repo_file: str):
force_download=force_download,
token=token,
headers=headers,
tqdm_class=inner_tqdm_class,
)

if constants.HF_HUB_ENABLE_HF_TRANSFER:
Expand All @@ -293,14 +296,18 @@ def _inner_hf_hub_download(repo_file: str):
for file in filtered_repo_files:
_inner_hf_hub_download(file)
else:
thread_map(
_inner_hf_hub_download,
filtered_repo_files,
desc=f"Fetching {len(filtered_repo_files)} files",
max_workers=max_workers,
# User can use its own tqdm class or the default one from `huggingface_hub.utils`
tqdm_class=tqdm_class or hf_tqdm,
)
if skip_outer_tqdm:
with ThreadPoolExecutor(max_workers=max_workers) as ex:
list(ex.map(_inner_hf_hub_download, filtered_repo_files))
else:
thread_map(
_inner_hf_hub_download,
filtered_repo_files,
desc=f"Fetching {len(filtered_repo_files)} files",
max_workers=max_workers,
# User can use its own tqdm class or the default one from `huggingface_hub.utils`
tqdm_class=tqdm_class or hf_tqdm,
)

if local_dir is not None:
return str(os.path.realpath(local_dir))
Expand Down
42 changes: 33 additions & 9 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union, Type
from urllib.parse import quote, urlparse

import requests
Expand Down Expand Up @@ -314,6 +314,7 @@ def http_get(
displayed_filename: Optional[str] = None,
_nb_retries: int = 5,
_tqdm_bar: Optional[tqdm] = None,
tqdm_class: Optional[Type[tqdm]] = None,
) -> None:
"""
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
Expand Down Expand Up @@ -397,14 +398,27 @@ def http_get(

# Stream file to buffer
progress_cm: tqdm = (
tqdm( # type: ignore[assignment]
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc=displayed_filename,
disable=is_tqdm_disabled(logger.getEffectiveLevel()),
name="huggingface_hub.http_get",
(
tqdm( # type: ignore[assignment]
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc=displayed_filename,
disable=is_tqdm_disabled(logger.getEffectiveLevel()),
name="huggingface_hub.http_get",
)
if tqdm_class is None
else
tqdm_class( # type: ignore[assignment]
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc=displayed_filename,
disable=is_tqdm_disabled(logger.getEffectiveLevel()),
name="huggingface_hub.http_get",
)
)
if _tqdm_bar is None
else contextlib.nullcontext(_tqdm_bar)
Expand Down Expand Up @@ -475,6 +489,7 @@ def http_get(
expected_size=expected_size,
_nb_retries=_nb_retries - 1,
_tqdm_bar=_tqdm_bar,
tqdm_class=tqdm_class,
)

if expected_size is not None and expected_size != temp_file.tell():
Expand Down Expand Up @@ -681,6 +696,7 @@ def hf_hub_download(
resume_download: Optional[bool] = None,
force_filename: Optional[str] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
tqdm_class: Optional[Type[tqdm]] = None,
) -> str:
"""Download a given file if it's not already present in the local cache.

Expand Down Expand Up @@ -855,6 +871,7 @@ def hf_hub_download(
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
tqdm_class=tqdm_class,
)
else:
return _hf_hub_download_to_cache_dir(
Expand All @@ -874,6 +891,7 @@ def hf_hub_download(
# Additional options
local_files_only=local_files_only,
force_download=force_download,
tqdm_class=tqdm_class,
)


Expand All @@ -895,6 +913,7 @@ def _hf_hub_download_to_cache_dir(
# Additional options
local_files_only: bool,
force_download: bool,
tqdm_class: Optional[Type[tqdm]],
) -> str:
"""Download a given file to a cache folder, if not already present.

Expand Down Expand Up @@ -1015,6 +1034,7 @@ def _hf_hub_download_to_cache_dir(
expected_size=expected_size,
filename=filename,
force_download=force_download,
tqdm_class=tqdm_class,
)
if not os.path.exists(pointer_path):
_create_symlink(blob_path, pointer_path, new_blob=True)
Expand All @@ -1041,6 +1061,7 @@ def _hf_hub_download_to_local_dir(
cache_dir: str,
force_download: bool,
local_files_only: bool,
tqdm_class: Optional[Type[tqdm]],
) -> str:
"""Download a given file to a local folder, if not already present.

Expand Down Expand Up @@ -1142,6 +1163,7 @@ def _hf_hub_download_to_local_dir(
expected_size=expected_size,
filename=filename,
force_download=force_download,
tqdm_class=tqdm_class,
)

write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
Expand Down Expand Up @@ -1498,6 +1520,7 @@ def _download_to_tmp_and_move(
expected_size: Optional[int],
filename: str,
force_download: bool,
tqdm_class: Optional[Type[tqdm]],
) -> None:
"""Download content from a URL to a destination path.

Expand Down Expand Up @@ -1547,6 +1570,7 @@ def _download_to_tmp_and_move(
resume_size=resume_size,
headers=headers,
expected_size=expected_size,
tqdm_class=tqdm_class,
)

logger.info(f"Download complete. Moving file to {destination_path}")
Expand Down