1313from requests .exceptions import HTTPError
1414from transformers import AutoTokenizer , PreTrainedTokenizer , PreTrainedTokenizerFast
1515
16- from QEfficient .utils .constants import QEFF_MODELS_DIR
16+ from QEfficient .utils .constants import QEFF_MODELS_DIR , Constants
1717from QEfficient .utils .logging_utils import logger
1818
1919
20+ class DownloadRetryLimitExceeded (Exception ):
21+ """
22+ Used for raising error when hf_download fails to download the model after given max_retries.
23+ """
24+
25+
2026def login_and_download_hf_lm (model_name , * args , ** kwargs ):
2127 logger .info (f"loading HuggingFace model for { model_name } " )
2228 hf_token = kwargs .pop ("hf_token" , None )
@@ -37,12 +43,12 @@ def hf_download(
3743 hf_token : Optional [str ] = None ,
3844 allow_patterns : Optional [List [str ]] = None ,
3945 ignore_patterns : Optional [List [str ]] = None ,
46+ max_retries : Optional [int ] = Constants .MAX_RETRIES ,
4047):
4148 # Setup cache_dir
4249 if cache_dir is not None :
4350 os .makedirs (cache_dir , exist_ok = True )
4451
45- max_retries = 5
4652 retry_count = 0
4753 while retry_count < max_retries :
4854 try :
@@ -59,14 +65,23 @@ def hf_download(
5965 except requests .ReadTimeout as e :
6066 logger .info (f"Read timeout: { e } " )
6167 retry_count += 1
62-
6368 except HTTPError as e :
64- retry_count = max_retries
6569 if e .response .status_code == 401 :
6670 logger .info ("You need to pass a valid `--hf_token=...` to download private checkpoints." )
71+ raise e
72+ except OSError as e :
73+ if "Consistency check failed" in str (e ):
74+ logger .info (
75+ "Consistency check failed during model download. The file appears to be incomplete. Resuming the download..."
76+ )
77+ retry_count += 1
6778 else :
6879 raise e
6980
81+ if retry_count >= max_retries :
82+ raise DownloadRetryLimitExceeded (
83+ f"Unable to download full model after { max_retries } tries. If the model fileS are huge in size, please try again."
84+ )
7085 return model_path
7186
7287
0 commit comments