4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -1057,4 +1057,8 @@
- local: internal/time_series_utils
title: Utilities for Time Series
title: Internal helpers
- sections:
- local: reference/environment_variables
title: Environment Variables
title: Reference
title: API
67 changes: 67 additions & 0 deletions docs/source/en/reference/environment_variables.md
@@ -0,0 +1,67 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Environment Variables

**Collaborator:** We should add a warning, as this should not really be run on login nodes or nodes that share bandwidth, at the risk of a lot of anger from your colleagues! 🤣

Suggested change:

- Debugging issues involving hf_transfer can be extremely difficult, making maintenance and issue resolution frustrating.
- User experience may be slightly degraded, with a more fragmented progress bar, corner cases when using Ctrl+C, and the overhead of launching a subprocess.
- hf_transfer only provides a speed boost if the machine has sufficient bandwidth; otherwise, performance stays the same or may even regress.
- Spawning a subprocess on each CPU core can heavily impact system performance, potentially freezing or significantly slowing down laptops.
- Some features are not supported, including resumable downloads, proxies, and others.

**Contributor (author):** True! The intention behind having this available is primarily for anyone working with HF on the cloud.

**@inf3rnus (author), Mar 21, 2025:** One other note: we're using threads to avoid the painful overhead of Python multiprocessing 😄 (processes take an additional 4-6s to instantiate vs. threads), but a laptop being totally overwhelmed by the processing load will indeed happen on smaller machines!

## HF_ENABLE_PARALLEL_DOWNLOADING

Disabled by default. When enabled, models with sharded weight files are downloaded in parallel, which can significantly decrease the time to load large models, often producing _speedups of greater than 50%_.

Accepts the string `"false"` or `"true"`, e.g. `os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"`.
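In the `hub.py` changes in this PR, the flag is read with `json.loads`, so the value must be lowercase JSON `true`/`false`; a value like `"True"` would raise a decode error. A minimal sketch mirroring that parsing:

```py
import json
import os

os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"

# same parsing as get_checkpoint_shard_files: JSON-decode the env var string
enabled = json.loads(os.environ.get("HF_ENABLE_PARALLEL_DOWNLOADING", "false"))
print(enabled)  # True
```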

While downloading is already parallelized at the file level when `HF_HUB_ENABLE_HF_TRANSFER` is enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` additionally parallelizes across files, so multiple shard files are downloaded concurrently. This can greatly speed up downloads if your machine has the network and IO bandwidth to handle it.

For example, here's a comparison for `facebook/opt-30b` on an AWS EC2 `g4dn.metal`:

- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` disabled: ~45s download
- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` enabled: ~12s download

To fully saturate a machine with massive network bandwidth, set `HF_ENABLE_PARALLEL_DOWNLOADING="true"` and `HF_HUB_ENABLE_HF_TRANSFER="1"`.

_Note: profile your code before committing to this environment variable; it will not produce speedups for smaller models._

```py
import os
from transformers import pipeline

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true" # enable the parallelized pool of downloader threads

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
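To act on the profiling advice above, a minimal, hypothetical timing harness can be used; the `download` function here is a placeholder for your real `pipeline(...)` or `from_pretrained(...)` call:

```py
import os
import time

os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"

def download():
    # placeholder for the real call, e.g.:
    # pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
    time.sleep(0.01)

start = time.perf_counter()
download()
elapsed = time.perf_counter() - start
# rerun with the flag set to "false" and compare the two timings
print(f"download took {elapsed:.2f}s")
```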

## HF_PARALLEL_DOWNLOADING_WORKERS

Determines how many threads should be used when the parallel downloading of model weight shards is enabled.

Defaults to `8`, though fewer may run if there are fewer shard files than workers; the effective worker count is the `min()` of `HF_PARALLEL_DOWNLOADING_WORKERS` and the number of shard files to download.
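The effective worker count can be sketched as follows, mirroring the `min()` logic in the `hub.py` changes (the shard count here is hypothetical):

```py
import json
import os

os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "8"

num_shard_files = 2  # hypothetical: the checkpoint has two shard files
num_workers = json.loads(os.environ.get("HF_PARALLEL_DOWNLOADING_WORKERS", "8"))

# never spawn more workers than there are files to download
effective_workers = min(num_shard_files, num_workers)
print(effective_workers)  # 2
```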
**Collaborator:** That could be automated!

**@inf3rnus (author), Mar 21, 2025:** This is absolutely true! That said, if this PR were to be merged, that should be an iteration-2 enhancement. The argument is that people will likely need to tune this anyway to squeeze every last drop out of it, although there's probably a simple, sane heuristic we could use 🤔


For example, if there are 2 shard files but 8 workers are specified, only two workers will spawn.

```py
import os
from transformers import pipeline

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true" # enable the parallelized pool of downloader threads
os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "12" # specify a non-default number of workers

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
100 changes: 94 additions & 6 deletions src/transformers/utils/hub.py
@@ -17,9 +17,11 @@

import json
import os
import queue
import re
import sys
import tempfile
import threading
import warnings
from concurrent import futures
from pathlib import Path
@@ -1108,15 +1110,101 @@ def get_checkpoint_shard_files(
sharded_metadata["weight_map"] = index["weight_map"].copy()

# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
if checkpoint_exists(pretrained_model_name_or_path):
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
return shard_filenames, sharded_metadata

# At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
# or download the files
cached_filenames = cached_files(
args_list = [
(
pretrained_model_name_or_path,
shard_filename,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
user_agent,
revision,
subfolder,
_commit_hash,
)
for shard_filename in shard_filenames
]

cached_filenames = []

if json.loads(os.environ.get("HF_ENABLE_PARALLEL_DOWNLOADING", "false")):
num_workers = json.loads(os.environ.get("HF_PARALLEL_DOWNLOADING_WORKERS", "8"))

# make sure you don't have excessive workers
num_workers = min(len(args_list), num_workers)

print(f"Downloading model weights in parallel with {num_workers} workers...")

cached_filenames += download_shards_with_threads(num_workers, args_list)

# reorder after the out of order execution that threads will produce
cached_filenames.sort()
else:
for args in args_list:
cached_filename = download_shard(args)
cached_filenames.append(cached_filename)

return cached_filenames, sharded_metadata


# NOTE: thin wrapper around os.path.isdir so tests can monkeypatch it to control get_checkpoint_shard_files
def checkpoint_exists(pretrained_model_name_or_path):
return os.path.isdir(pretrained_model_name_or_path)


def worker(q, cached_filenames):
    while True:
        try:
            # get_nowait avoids the race where empty() returns False but another
            # thread drains the last item, leaving this thread blocked in get()
            args = q.get_nowait()
        except queue.Empty:
            break
        filename = download_shard(args)
        cached_filenames.append(filename)
        q.task_done()


def download_shards_with_threads(num_workers, args_list):
q = queue.Queue()
cached_filenames = []

for args in args_list:
q.put(args)

threads = []
for _ in range(num_workers):
t = threading.Thread(target=worker, args=(q, cached_filenames))
t.start()
threads.append(t)

for t in threads:
t.join()

return cached_filenames


def download_shard(args):
(
pretrained_model_name_or_path,
shard_filename,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
user_agent,
revision,
subfolder,
_commit_hash,
) = args

cached_filename = cached_file(
pretrained_model_name_or_path,
shard_filenames,
shard_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -1129,7 +1217,7 @@ def get_checkpoint_shard_files(
_commit_hash=_commit_hash,
)

return cached_filenames, sharded_metadata
return cached_filename


def create_and_tag_model_card(
104 changes: 104 additions & 0 deletions tests/utils/test_hub_utils_parallel.py
@@ -0,0 +1,104 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import unittest

from huggingface_hub import hf_hub_download

import transformers
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
TRANSFORMERS_CACHE,
)
from transformers.utils.hub import get_checkpoint_shard_files


RANDOM_BERT_SHARDED = "hf-internal-testing/tiny-random-bert-sharded"
CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert-sharded")
FULL_COMMIT_HASH = "04a52fc6ff50bf21639d65be441bd2bd8410ef5d"

CHECK_POINT_EXISTS_FUNC = transformers.utils.hub.checkpoint_exists


class GetFromCacheTestsParallel(unittest.TestCase):
def setUp(self) -> None:
if os.path.exists(CACHE_DIR):
shutil.rmtree(CACHE_DIR)

# NOTE mock checkpoint_exists so it's a function that returns False
self._original_checkpoint_exists = CHECK_POINT_EXISTS_FUNC
# Mock to always make it return False
transformers.utils.hub.checkpoint_exists = lambda *args, **kwargs: False

os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"
os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "8"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

def tearDown(self) -> None:
# Restore the original function after the test
transformers.utils.hub.checkpoint_exists = self._original_checkpoint_exists

del os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"]
del os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"]
del os.environ["HF_HUB_ENABLE_HF_TRANSFER"]

def test_get_checkpoint_shard_files(self):
hf_hub_download(
RANDOM_BERT_SHARDED,
filename=SAFE_WEIGHTS_INDEX_NAME,
)

index_filename = os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, SAFE_WEIGHTS_INDEX_NAME)

cached_filenames, sharded_metadata = get_checkpoint_shard_files(
RANDOM_BERT_SHARDED, index_filename, revision=FULL_COMMIT_HASH
)

# Should have downloaded the file in here
self.assertTrue(os.path.isdir(CACHE_DIR))

# make sure the files we were supposed to download were downloaded
with open(index_filename, "r") as f:
index = json.loads(f.read())

weight_map_file_names = sorted(set(index["weight_map"].values()))

# make sure we have the same number of cached files as the number of files in the weight map
self.assertEqual(len(weight_map_file_names), len(cached_filenames))

for i, cached_filename in enumerate(cached_filenames):
# now make sure each file exists
self.assertTrue(os.path.exists(cached_filename))

# now make sure each file was in the set of files we told the function to download
filename = os.path.basename(cached_filename)

# both lists are sorted the same way, so the names should match positionally
self.assertEqual(weight_map_file_names[i], filename)

# for extra safety we now perform an integration test for the cached data
model = transformers.AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-sharded",
)
self.assertIsNotNone(model)

def test_get_checkpoint_shard_files_integration(self):
model = transformers.AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-sharded",
)
self.assertIsNotNone(model)