diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 79f8eb3d490d..3c390d1d52c9 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1057,4 +1057,8 @@
     - local: internal/time_series_utils
       title: Utilities for Time Series
     title: Internal helpers
+  - sections:
+    - local: reference/environment_variables
+      title: Environment Variables
+    title: Reference
   title: API
diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md
new file mode 100644
index 000000000000..21c7cc42c23d
--- /dev/null
+++ b/docs/source/en/reference/environment_variables.md
@@ -0,0 +1,67 @@
+
+# Environment Variables
+
+## HF_ENABLE_PARALLEL_DOWNLOADING
+
+Disabled by default. When enabled, models with sharded weight files are downloaded in parallel, which can significantly decrease the time to load large models, often producing _speed-ups of more than 50%_.
+
+Set it to the string `"true"` or `"false"`, e.g. `os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"`.
+
+While `HF_HUB_ENABLE_HF_TRANSFER` already parallelizes the download of each individual file, `HF_ENABLE_PARALLEL_DOWNLOADING` additionally downloads several shard files concurrently. This can greatly speed up downloads if your machine has the network and I/O bandwidth to handle it.
+
+For example, here's a comparison for `facebook/opt-30b` on an AWS EC2 `g4dn.metal` instance:
+
+- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` disabled
+  - ~45s download
+- `HF_HUB_ENABLE_HF_TRANSFER` enabled, `HF_ENABLE_PARALLEL_DOWNLOADING` enabled
+  - ~12s download
+
+To fully saturate a machine with very high network bandwidth, set both `HF_ENABLE_PARALLEL_DOWNLOADING="true"` and `HF_HUB_ENABLE_HF_TRANSFER="1"`.
+
+_Note: profile your code before committing to this environment variable; it will not produce speed-ups for smaller models._
+
+```py
+import os
+from transformers import pipeline
+
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"  # enable the parallel pool of downloader threads
+
+model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
+```
+
+## HF_PARALLEL_DOWNLOADING_WORKERS
+
+Determines how many worker threads are used when parallel downloading of model weight shards is enabled.
+
+The default is `8`, although fewer workers may run when there are fewer shard files than workers, i.e. the effective worker count is the minimum of `HF_PARALLEL_DOWNLOADING_WORKERS` and the number of shard files to download.
+
+For example, if there are 2 shard files but 8 workers are specified, only two workers will spawn.
+
+```py
+import os
+from transformers import pipeline
+
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"  # enable the parallel pool of downloader threads
+os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "12"  # specify a non-default number of workers
+
+model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
+```
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 01d19c214053..67549bda9b9b 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -17,9 +17,11 @@
 import json
 import os
+import queue
 import re
 import sys
 import tempfile
+import threading
 import warnings
 from concurrent import futures
 from pathlib import Path
@@ -1108,15 +1110,101 @@ def get_checkpoint_shard_files(
         sharded_metadata["weight_map"] = index["weight_map"].copy()
 
     # First, let's deal with local folder.
-    if os.path.isdir(pretrained_model_name_or_path):
+    if checkpoint_exists(pretrained_model_name_or_path):
         shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
         return shard_filenames, sharded_metadata
 
-    # At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
-    # or download the files
-    cached_filenames = cached_files(
+    # At this stage pretrained_model_name_or_path is a model identifier on the Hub: build one tuple of download
+    # arguments per shard file, then fetch each shard from the cache or download it.
+    args_list = [
+        (
+            pretrained_model_name_or_path,
+            shard_filename,
+            cache_dir,
+            force_download,
+            proxies,
+            resume_download,
+            local_files_only,
+            token,
+            user_agent,
+            revision,
+            subfolder,
+            _commit_hash,
+        )
+        for shard_filename in shard_filenames
+    ]
+
+    cached_filenames = []
+
+    if json.loads(os.environ.get("HF_ENABLE_PARALLEL_DOWNLOADING", "false")):
+        num_workers = json.loads(os.environ.get("HF_PARALLEL_DOWNLOADING_WORKERS", "8"))
+
+        # never spawn more workers than there are shard files to download
+        num_workers = min(len(args_list), num_workers)
+
+        print(f"Downloading model weights in parallel with {num_workers} workers...")
+
+        cached_filenames += download_shards_with_threads(num_workers, args_list)
+
+        # restore a deterministic order after the out-of-order completion of the worker threads
+        cached_filenames.sort()
+    else:
+        for args in args_list:
+            cached_filename = download_shard(args)
+            cached_filenames.append(cached_filename)
+
+    return cached_filenames, sharded_metadata
+
+
+# NOTE: split out so tests can monkeypatch it and force get_checkpoint_shard_files onto the Hub/cache path
+def checkpoint_exists(pretrained_model_name_or_path):
+    return os.path.isdir(pretrained_model_name_or_path)
+
+
+def worker(q, cached_filenames):
+    # each worker thread drains the shared queue until it is empty
+    while not q.empty():
+        args = q.get()
+        filename = download_shard(args)
+        cached_filenames.append(filename)
+        q.task_done()
+
+
+def download_shards_with_threads(num_workers, args_list):
+    q = queue.Queue()
+    cached_filenames = []
+
+    for args in args_list:
+        q.put(args)
+
+    threads = []
+    for _ in range(num_workers):
+        t = threading.Thread(target=worker, args=(q, cached_filenames))
+        t.start()
+        threads.append(t)
+
+    for t in threads:
+        t.join()
+
+    return cached_filenames
+
+
+def download_shard(args):
+    (
+        pretrained_model_name_or_path,
+        shard_filename,
+        cache_dir,
+        force_download,
+        proxies,
+        resume_download,
+        local_files_only,
+        token,
+        user_agent,
+        revision,
+        subfolder,
+        _commit_hash,
+    ) = args
+
+    cached_filename = cached_file(
         pretrained_model_name_or_path,
-        shard_filenames,
+        shard_filename,
         cache_dir=cache_dir,
         force_download=force_download,
         proxies=proxies,
@@ -1129,7 +1217,7 @@ def get_checkpoint_shard_files(
         _commit_hash=_commit_hash,
     )
 
-    return cached_filenames, sharded_metadata
+    return cached_filename
 
 
 def create_and_tag_model_card(
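For readers who want the download pool in isolation, here is a minimal, self-contained sketch of the queue-plus-worker-threads pattern that `download_shards_with_threads` implements above. The `fetch` function and the shard names are placeholders for illustration only (standing in for `download_shard` and the per-shard argument tuples) and are not part of the transformers API; the sketch also uses `get_nowait` instead of the `empty()`/`get()` pair for robustness.

```py
import queue
import threading


def fetch(item):
    # stand-in for the real per-shard download (download_shard in the patch above)
    return f"downloaded-{item}"


def run_pool(items, num_workers=8):
    q = queue.Queue()
    results = []

    for item in items:
        q.put(item)

    def worker():
        # drain the shared queue; get_nowait avoids blocking forever if another
        # thread takes the last item between an empty() check and a get()
        while True:
            try:
                item = q.get_nowait()
            except queue.Empty:
                break
            results.append(fetch(item))
            q.task_done()

    threads = [threading.Thread(target=worker) for _ in range(min(num_workers, len(items)))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # worker threads finish out of order, so sort for a deterministic result
    return sorted(results)


print(run_pool(["shard-00001", "shard-00002", "shard-00003"]))
```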
diff --git a/tests/utils/test_hub_utils_parallel.py b/tests/utils/test_hub_utils_parallel.py
new file mode 100644
index 000000000000..a85f34053794
--- /dev/null
+++ b/tests/utils/test_hub_utils_parallel.py
@@ -0,0 +1,104 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import shutil
+import unittest
+
+from huggingface_hub import hf_hub_download
+
+import transformers
+from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    TRANSFORMERS_CACHE,
+)
+from transformers.utils.hub import get_checkpoint_shard_files
+
+
+RANDOM_BERT_SHARDED = "hf-internal-testing/tiny-random-bert-sharded"
+CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert-sharded")
+FULL_COMMIT_HASH = "04a52fc6ff50bf21639d65be441bd2bd8410ef5d"
+
+CHECK_POINT_EXISTS_FUNC = transformers.utils.hub.checkpoint_exists
+
+
+class GetFromCacheTestsParallel(unittest.TestCase):
+    def setUp(self) -> None:
+        if os.path.exists(CACHE_DIR):
+            shutil.rmtree(CACHE_DIR)
+
+        # mock checkpoint_exists to always return False so the shards are fetched from the Hub rather than a local dir
+        self._original_checkpoint_exists = CHECK_POINT_EXISTS_FUNC
+        transformers.utils.hub.checkpoint_exists = lambda *args, **kwargs: False
+
+        os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"
+        os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"] = "8"
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+    def tearDown(self) -> None:
+        # restore the original function after the test
+        transformers.utils.hub.checkpoint_exists = self._original_checkpoint_exists
+
+        del os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"]
+        del os.environ["HF_PARALLEL_DOWNLOADING_WORKERS"]
+        del os.environ["HF_HUB_ENABLE_HF_TRANSFER"]
+
+    def test_get_checkpoint_shard_files(self):
+        hf_hub_download(
+            RANDOM_BERT_SHARDED,
+            filename=SAFE_WEIGHTS_INDEX_NAME,
+        )
+
+        index_filename = os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, SAFE_WEIGHTS_INDEX_NAME)
+
+        cached_filenames, sharded_metadata = get_checkpoint_shard_files(
+            RANDOM_BERT_SHARDED, index_filename, revision=FULL_COMMIT_HASH
+        )
+
+        # the shards should have been downloaded into the cache directory
+        self.assertTrue(os.path.isdir(CACHE_DIR))
+
+        # make sure the files we were supposed to download were downloaded
+        with open(index_filename, "r") as f:
+            index = json.loads(f.read())
+
+        weight_map_file_names = sorted(set(index["weight_map"].values()))
+
+        # make sure we have the same number of cached files as files in the weight map
+        self.assertEqual(len(weight_map_file_names), len(cached_filenames))
+
+        for i, cached_filename in enumerate(cached_filenames):
+            # each cached file must exist on disk
+            self.assertTrue(os.path.exists(cached_filename))
+
+            # and must be one of the files we told the function to download;
+            # both lists are sorted the same way, so compare position by position
+            filename = cached_filename.split("/").pop()
+            self.assertEqual(weight_map_file_names[i], filename)
+
+        # for extra safety, perform an integration test on the cached data
+        model = transformers.AutoModel.from_pretrained(
+            "hf-internal-testing/tiny-random-bert-sharded",
+        )
+        self.assertIsNotNone(model)
+
+    def test_get_checkpoint_shard_files_integration(self):
+        model = transformers.AutoModel.from_pretrained(
+            "hf-internal-testing/tiny-random-bert-sharded",
+        )
+        self.assertIsNotNone(model)
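The documentation added above recommends profiling before committing to the flag; one way such a measurement could look is sketched below. This is illustrative only and not part of the patch: it simply times a cold pipeline load with `time.perf_counter`, using the same model the docs use. Clear the local model cache between runs so both the baseline and the parallel run measure an actual download.

```py
import os
import time

# illustrative benchmark sketch: run once with the flag set to "false" (baseline)
# and once with it set to "true", clearing the local model cache between runs
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_ENABLE_PARALLEL_DOWNLOADING"] = "true"

from transformers import pipeline

start = time.perf_counter()
pipe = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
print(f"pipeline load took {time.perf_counter() - start:.1f}s")
```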