Skip to content

Commit ede63d1

Browse files
authored
Compact output in snapshot_download and hf download (#3523)
* Multi-threaded snapshot download * Compact output in snapshot_download * comment
1 parent 99bfce1 commit ede63d1

File tree

4 files changed

+89
-11
lines changed

4 files changed

+89
-11
lines changed

src/huggingface_hub/_snapshot_download.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
1919
from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
20-
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
20+
from .utils import OfflineModeIsEnabled, filter_repo_objects, is_tqdm_disabled, logging, validate_hf_hub_args
2121
from .utils import tqdm as hf_tqdm
2222

2323

@@ -379,12 +379,56 @@ def snapshot_download(
379379

380380
results: List[Union[str, DryRunFileInfo]] = []
381381

382+
# Users can supply their own tqdm class; otherwise the default one from `huggingface_hub.utils` is used
383+
tqdm_class = tqdm_class or hf_tqdm
384+
385+
# Create a progress bar for the bytes downloaded
386+
# This progress bar is shared across threads/files and gets updated each time we fetch
387+
# metadata for a file.
388+
bytes_progress = tqdm_class(
389+
desc="Downloading (incomplete total...)",
390+
disable=is_tqdm_disabled(log_level=logger.getEffectiveLevel()),
391+
total=0,
392+
initial=0,
393+
unit="B",
394+
unit_scale=True,
395+
name="huggingface_hub.snapshot_download",
396+
)
397+
398+
class _AggregatedTqdm:
399+
"""Fake tqdm object to aggregate progress into the parent `bytes_progress` bar.
400+
401+
In practice the `_AggregatedTqdm` object is never displayed; it is only used to update
402+
the `bytes_progress` bar from each thread/file download.
403+
"""
404+
405+
def __init__(self, *args, **kwargs):
406+
# Adjust the total of the parent progress bar
407+
total = kwargs.pop("total", None)
408+
if total is not None:
409+
bytes_progress.total += total
410+
bytes_progress.refresh()
411+
412+
# Adjust initial of the parent progress bar
413+
initial = kwargs.pop("initial", 0)
414+
if initial:
415+
bytes_progress.update(initial)
416+
417+
def __enter__(self):
418+
return self
419+
420+
def __exit__(self, exc_type, exc_value, traceback):
421+
pass
422+
423+
def update(self, n: Optional[Union[int, float]] = 1) -> None:
424+
bytes_progress.update(n)
425+
382426
# we pass the commit_hash to hf_hub_download
383427
# so no network call happens if we already
384428
# have the file locally.
385429
def _inner_hf_hub_download(repo_file: str) -> None:
386430
results.append(
387-
hf_hub_download( # type: ignore[no-matching-overload] # ty not happy, don't know why :/
431+
hf_hub_download( # type: ignore
388432
repo_id,
389433
filename=repo_file,
390434
repo_type=repo_type,
@@ -399,6 +443,7 @@ def _inner_hf_hub_download(repo_file: str) -> None:
399443
force_download=force_download,
400444
token=token,
401445
headers=headers,
446+
tqdm_class=_AggregatedTqdm, # type: ignore
402447
dry_run=dry_run,
403448
)
404449
)
@@ -408,10 +453,11 @@ def _inner_hf_hub_download(repo_file: str) -> None:
408453
filtered_repo_files,
409454
desc=tqdm_desc,
410455
max_workers=max_workers,
411-
# User can use its own tqdm class or the default one from `huggingface_hub.utils`
412-
tqdm_class=tqdm_class or hf_tqdm,
456+
tqdm_class=tqdm_class,
413457
)
414458

459+
bytes_progress.set_description("Download complete")
460+
415461
if dry_run:
416462
assert all(isinstance(r, DryRunFileInfo) for r in results)
417463
return results # type: ignore

src/huggingface_hub/file_download.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from urllib.parse import quote, urlparse
1414

1515
import httpx
16+
from tqdm.auto import tqdm as base_tqdm
1617

1718
from . import constants
1819
from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata
@@ -348,6 +349,7 @@ def http_get(
348349
headers: Optional[dict[str, Any]] = None,
349350
expected_size: Optional[int] = None,
350351
displayed_filename: Optional[str] = None,
352+
tqdm_class: Optional[type[base_tqdm]] = None,
351353
_nb_retries: int = 5,
352354
_tqdm_bar: Optional[tqdm] = None,
353355
) -> None:
@@ -425,6 +427,7 @@ def http_get(
425427
total=total,
426428
initial=resume_size,
427429
name="huggingface_hub.http_get",
430+
tqdm_class=tqdm_class,
428431
_tqdm_bar=_tqdm_bar,
429432
)
430433

@@ -453,6 +456,7 @@ def http_get(
453456
resume_size=new_resume_size,
454457
headers=initial_headers,
455458
expected_size=expected_size,
459+
tqdm_class=tqdm_class,
456460
_nb_retries=_nb_retries - 1,
457461
_tqdm_bar=_tqdm_bar,
458462
)
@@ -472,6 +476,7 @@ def xet_get(
472476
headers: dict[str, str],
473477
expected_size: Optional[int] = None,
474478
displayed_filename: Optional[str] = None,
479+
tqdm_class: Optional[type[base_tqdm]] = None,
475480
_tqdm_bar: Optional[tqdm] = None,
476481
) -> None:
477482
"""
@@ -554,6 +559,7 @@ def token_refresher() -> tuple[str, int]:
554559
total=expected_size,
555560
initial=0,
556561
name="huggingface_hub.xet_get",
562+
tqdm_class=tqdm_class,
557563
_tqdm_bar=_tqdm_bar,
558564
)
559565

@@ -685,10 +691,10 @@ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
685691

686692
# Symlinks are not supported => let's move or copy the file.
687693
if new_blob:
688-
logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
694+
logger.debug(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
689695
shutil.move(abs_src, abs_dst, copy_function=_copy_no_matter_what)
690696
else:
691-
logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
697+
logger.debug(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
692698
shutil.copyfile(abs_src, abs_dst)
693699

694700

@@ -763,6 +769,7 @@ def hf_hub_download(
763769
local_files_only: bool = False,
764770
headers: Optional[dict[str, str]] = None,
765771
endpoint: Optional[str] = None,
772+
tqdm_class: Optional[type[base_tqdm]] = None,
766773
dry_run: Literal[False] = False,
767774
) -> str: ...
768775

@@ -786,6 +793,7 @@ def hf_hub_download(
786793
local_files_only: bool = False,
787794
headers: Optional[dict[str, str]] = None,
788795
endpoint: Optional[str] = None,
796+
tqdm_class: Optional[type[base_tqdm]] = None,
789797
dry_run: Literal[True] = True,
790798
) -> DryRunFileInfo: ...
791799

@@ -809,6 +817,7 @@ def hf_hub_download(
809817
local_files_only: bool = False,
810818
headers: Optional[dict[str, str]] = None,
811819
endpoint: Optional[str] = None,
820+
tqdm_class: Optional[type[base_tqdm]] = None,
812821
dry_run: bool = False,
813822
) -> Union[str, DryRunFileInfo]: ...
814823

@@ -832,6 +841,7 @@ def hf_hub_download(
832841
local_files_only: bool = False,
833842
headers: Optional[dict[str, str]] = None,
834843
endpoint: Optional[str] = None,
844+
tqdm_class: Optional[type[base_tqdm]] = None,
835845
dry_run: bool = False,
836846
) -> Union[str, DryRunFileInfo]:
837847
"""Download a given file if it's not already present in the local cache.
@@ -908,6 +918,11 @@ def hf_hub_download(
908918
local cached file if it exists.
909919
headers (`dict`, *optional*):
910920
Additional headers to be sent with the request.
921+
tqdm_class (`tqdm`, *optional*):
922+
If provided, overrides the default behavior for the progress bar. Passed
923+
argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
924+
Defaults to the custom HF progress bar that can be disabled by setting the
925+
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
911926
dry_run (`bool`, *optional*, defaults to `False`):
912927
If `True`, perform a dry run without actually downloading the file. Returns a
913928
[`DryRunFileInfo`] object containing information about what would be downloaded.
@@ -985,6 +1000,7 @@ def hf_hub_download(
9851000
cache_dir=cache_dir,
9861001
force_download=force_download,
9871002
local_files_only=local_files_only,
1003+
tqdm_class=tqdm_class,
9881004
dry_run=dry_run,
9891005
)
9901006
else:
@@ -1004,6 +1020,7 @@ def hf_hub_download(
10041020
# Additional options
10051021
local_files_only=local_files_only,
10061022
force_download=force_download,
1023+
tqdm_class=tqdm_class,
10071024
dry_run=dry_run,
10081025
)
10091026

@@ -1025,6 +1042,7 @@ def _hf_hub_download_to_cache_dir(
10251042
# Additional options
10261043
local_files_only: bool,
10271044
force_download: bool,
1045+
tqdm_class: Optional[type[base_tqdm]],
10281046
dry_run: bool,
10291047
) -> Union[str, DryRunFileInfo]:
10301048
"""Download a given file to a cache folder, if not already present.
@@ -1189,6 +1207,7 @@ def _hf_hub_download_to_cache_dir(
11891207
force_download=force_download,
11901208
etag=etag,
11911209
xet_file_data=xet_file_data,
1210+
tqdm_class=tqdm_class,
11921211
)
11931212
if not os.path.exists(pointer_path):
11941213
_create_symlink(blob_path, pointer_path, new_blob=True)
@@ -1214,6 +1233,7 @@ def _hf_hub_download_to_local_dir(
12141233
cache_dir: str,
12151234
force_download: bool,
12161235
local_files_only: bool,
1236+
tqdm_class: Optional[type[base_tqdm]],
12171237
dry_run: bool,
12181238
) -> Union[str, DryRunFileInfo]:
12191239
"""Download a given file to a local folder, if not already present.
@@ -1377,6 +1397,7 @@ def _hf_hub_download_to_local_dir(
13771397
force_download=force_download,
13781398
etag=etag,
13791399
xet_file_data=xet_file_data,
1400+
tqdm_class=tqdm_class,
13801401
)
13811402

13821403
write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
@@ -1727,6 +1748,7 @@ def _download_to_tmp_and_move(
17271748
force_download: bool,
17281749
etag: Optional[str],
17291750
xet_file_data: Optional[XetFileData],
1751+
tqdm_class: Optional[type[base_tqdm]] = None,
17301752
) -> None:
17311753
"""Download content from a URL to a destination path.
17321754
@@ -1749,15 +1771,15 @@ def _download_to_tmp_and_move(
17491771
# By default, we will try to resume the download if possible.
17501772
# However, if the user has set `force_download=True`, then we should
17511773
# not resume the download => delete the incomplete file.
1752-
logger.info(f"Removing incomplete file '{incomplete_path}' (force_download=True)")
1774+
logger.debug(f"Removing incomplete file '{incomplete_path}' (force_download=True)")
17531775
incomplete_path.unlink(missing_ok=True)
17541776

17551777
with incomplete_path.open("ab") as f:
17561778
resume_size = f.tell()
17571779
message = f"Downloading '{filename}' to '{incomplete_path}'"
17581780
if resume_size > 0 and expected_size is not None:
17591781
message += f" (resume from {resume_size}/{expected_size})"
1760-
logger.info(message)
1782+
logger.debug(message)
17611783

17621784
if expected_size is not None: # might be None if HTTP header not set correctly
17631785
# Check disk space in both tmp and destination path
@@ -1772,6 +1794,7 @@ def _download_to_tmp_and_move(
17721794
headers=headers,
17731795
expected_size=expected_size,
17741796
displayed_filename=filename,
1797+
tqdm_class=tqdm_class,
17751798
)
17761799
else:
17771800
if xet_file_data is not None and not constants.HF_HUB_DISABLE_XET:
@@ -1787,9 +1810,10 @@ def _download_to_tmp_and_move(
17871810
resume_size=resume_size,
17881811
headers=headers,
17891812
expected_size=expected_size,
1813+
tqdm_class=tqdm_class,
17901814
)
17911815

1792-
logger.info(f"Download complete. Moving file to {destination_path}")
1816+
logger.debug(f"Download complete. Moving file to {destination_path}")
17931817
_chmod_and_move(incomplete_path, destination_path)
17941818

17951819

src/huggingface_hub/utils/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,11 @@
119119
parse_xet_file_data_from_response,
120120
refresh_xet_connection_info,
121121
)
122-
from .tqdm import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm, tqdm_stream_file
122+
from .tqdm import (
123+
are_progress_bars_disabled,
124+
disable_progress_bars,
125+
enable_progress_bars,
126+
is_tqdm_disabled,
127+
tqdm,
128+
tqdm_stream_file,
129+
)

src/huggingface_hub/utils/tqdm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def _get_progress_bar_context(
288288
unit: str = "B",
289289
unit_scale: bool = True,
290290
name: Optional[str] = None,
291+
tqdm_class: Optional[type[old_tqdm]] = None,
291292
_tqdm_bar: Optional[tqdm] = None,
292293
) -> ContextManager[tqdm]:
293294
if _tqdm_bar is not None:
@@ -296,7 +297,7 @@ def _get_progress_bar_context(
296297
# Makes it easier to use the same code path for both cases but in the latter
297298
# case, the progress bar is not closed when exiting the context manager.
298299

299-
return tqdm(
300+
return (tqdm_class or tqdm)( # type: ignore[return-value]
300301
unit=unit,
301302
unit_scale=unit_scale,
302303
total=total,

0 commit comments

Comments
 (0)