Improve Model Download Speeds By ~3x For Large Models #36870
New documentation file (added in this PR):
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Environment Variables

## HF_ENABLE_PARALLEL_DOWNLOADING

Disabled by default. Enables parallel downloading for models with sharded weight files, which can significantly decrease the time to load large models, often producing _speedups of greater than 50%_.
Set it to the string `"true"` or `"false"`, e.g. `os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"`.
While downloading is already parallelized at the file level when `HF_HUB_ENABLE_HF_TRANSFER` is enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` additionally parallelizes across files, so several shard files can be downloaded concurrently. This can greatly speed up downloads if your machine has the network and I/O bandwidth to handle it.
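
For intuition, the file-level fan-out can be pictured as a thread pool mapping a download function over the shard files. This is only an illustrative sketch, not the actual `transformers` implementation; `download_file` and `shard_urls` are hypothetical names:

```py
from concurrent.futures import ThreadPoolExecutor


def download_file(url: str) -> str:
    # Hypothetical stand-in: each call fetches one shard file. With
    # HF_HUB_ENABLE_HF_TRANSFER enabled, that single file is itself
    # downloaded in parallel chunks.
    return f"/tmp/{url}"


shard_urls = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

# HF_ENABLE_PARALLEL_DOWNLOADING adds this outer, per-file parallelism
# on top of the per-chunk parallelism inside each download.
with ThreadPoolExecutor(max_workers=8) as pool:
    local_paths = list(pool.map(download_file, shard_urls))
```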
e.g. here's a comparison for `facebook/opt-30b` on an AWS EC2 `g4dn.metal`:

- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` disabled
  - ~45s download
- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` enabled
  - ~12s download
To fully saturate a machine capable of massive network bandwidth, set `HF_ENABLE_PARALLEL_DOWNLOADING="true"` and `HF_HUB_ENABLE_HF_TRANSFER="1"`.

_Note: profile your code before committing to this environment variable; it will not produce speedups for smaller models._
```py
import os

# set these before importing transformers so they take effect
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"  # enable a parallelized pool of downloader threads

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
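
Since the note above recommends profiling, one simple way to measure the impact is to time the end-to-end load with the flag on and off (a sketch; with a cold cache the download dominates, so clear the cache or point `HF_HUB_CACHE` at a fresh directory between runs):

```py
import os
import time

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"  # flip to "false" for the baseline run

from transformers import pipeline

start = time.perf_counter()
model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
print(f"download + load took {time.perf_counter() - start:.1f}s")
```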
## HF_PARALLEL_DOWNLOADING_WORKERS

Determines how many threads are used when parallel downloading of model weight shards is enabled.

The default is `8`, though fewer threads will run if there are fewer shard files than workers; i.e. the effective worker count is the `min()` of `HF_PARALLEL_DOWNLOADING_WORKERS` and the number of shard files to download.
> **Review comment:** That can be automated!
>
> **Reply:** This is absolutely true! With that said, I feel like if this PR were to be merged, that should be an iteration-2 enhancement. The argument being that this is something people will likely need to tune anyway to squeeze every last drop out of it, although there's probably a simple, sane heuristic we could use 🤔
e.g. if there are 2 shard files but 8 workers are specified, only two workers will be spawned (see the worker-count sketch after the example below).
```py
import os

# set these before importing transformers so they take effect
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"  # enable a parallelized pool of downloader threads
os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "12"  # specify a non-default number of workers

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
New test file (added in this PR):
```py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import unittest

from huggingface_hub import hf_hub_download

import transformers
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    TRANSFORMERS_CACHE,
)
from transformers.utils.hub import get_checkpoint_shard_files


RANDOM_BERT_SHARDED = "hf-internal-testing/tiny-random-bert-sharded"
CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert-sharded")
FULL_COMMIT_HASH = "04a52fc6ff50bf21639d65be441bd2bd8410ef5d"

CHECK_POINT_EXISTS_FUNC = transformers.utils.hub.checkpoint_exists


class GetFromCacheTestsParallel(unittest.TestCase):
    def setUp(self) -> None:
        if os.path.exists(CACHE_DIR):
            shutil.rmtree(CACHE_DIR)

        # NOTE: mock checkpoint_exists so it always returns False,
        # forcing the shards to be downloaded instead of found locally
        self._original_checkpoint_exists = CHECK_POINT_EXISTS_FUNC
        transformers.utils.hub.checkpoint_exists = lambda *args, **kwargs: False

        os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"
        os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "8"
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    def tearDown(self) -> None:
        # Restore the original function after the test
        transformers.utils.hub.checkpoint_exists = self._original_checkpoint_exists

        del os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"]
        del os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"]
        del os.environ["HF_HUB_ENABLE_HF_TRANSFER"]

    def test_get_checkpoint_shard_files(self):
        hf_hub_download(
            RANDOM_BERT_SHARDED,
            filename=SAFE_WEIGHTS_INDEX_NAME,
        )

        index_filename = os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, SAFE_WEIGHTS_INDEX_NAME)

        cached_filenames, sharded_metadata = get_checkpoint_shard_files(
            RANDOM_BERT_SHARDED, index_filename, revision=FULL_COMMIT_HASH
        )

        # should have downloaded the files in here
        self.assertTrue(os.path.isdir(CACHE_DIR))

        # make sure the files we were supposed to download were downloaded
        with open(index_filename, "r") as f:
            index = json.loads(f.read())

        weight_map_file_names = sorted(set(index["weight_map"].values()))

        # make sure we have as many cached files as files in the weight map
        self.assertEqual(len(weight_map_file_names), len(cached_filenames))

        for i, cached_filename in enumerate(cached_filenames):
            # now make sure each file exists
            self.assertTrue(os.path.exists(cached_filename))

            # and that each file was in the set of files we told the function
            # to download (both lists are sorted the same way)
            filename = cached_filename.split("/").pop()
            self.assertEqual(weight_map_file_names[i], filename)

        # for extra safety we now perform an integration test on the cached data
        model = transformers.AutoModel.from_pretrained(
            "hf-internal-testing/tiny-random-bert-sharded",
        )
        self.assertIsNotNone(model)

    def test_get_checkpoint_shard_files_integration(self):
        model = transformers.AutoModel.from_pretrained(
            "hf-internal-testing/tiny-random-bert-sharded",
        )
        self.assertIsNotNone(model)
```
> **Review comment:** We should add a warning, as this should not really be run on login nodes or nodes that share bandwidth, at the risk of having a lot of anger from your colleagues! 🤣
>
> **Reply:** True! The intention behind having this available is primarily for anyone working with HF on the cloud.
>
> **Reply:** One other note: we're using threads to avoid the painful overhead of Python multiprocessing 😄 (it takes an additional 4-6s to instantiate processes vs. threads), but a laptop being totally overwhelmed by the processing load will indeed happen on smaller machines!
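
The threads-vs-processes point in the last comment is easy to check in isolation. A minimal sketch (not from the PR; numbers vary by machine) comparing pool start-up cost:

```py
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def noop(_):
    return None


if __name__ == "__main__":
    # Downloads are I/O-bound, so threads suffice; processes add
    # interpreter spin-up and pickling overhead for no benefit here.
    for executor_cls in (ThreadPoolExecutor, ProcessPoolExecutor):
        start = time.perf_counter()
        with executor_cls(max_workers=8) as pool:
            list(pool.map(noop, range(8)))
        print(f"{executor_cls.__name__}: {time.perf_counter() - start:.2f}s")
```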