@@ -46,8 +46,11 @@ def get_model(
4646 batch_size = kwargs .get ("batch_size" , 0 ) # for transformers backend only
4747 max_concurrency = kwargs .get ("max_concurrency" , 100 ) # for http-client backend only
4848 http_timeout = kwargs .get ("http_timeout" , 600 ) # for http-client backend only
49+ server_headers = kwargs .get ("server_headers" , None ) # for http-client backend only
50+ max_retries = kwargs .get ("max_retries" , 3 ) # for http-client backend only
51+ retry_backoff_factor = kwargs .get ("retry_backoff_factor" , 0.5 ) # for http-client backend only
4952 # 从kwargs中移除这些参数,避免传递给不相关的初始化函数
50- for param in ["batch_size" , "max_concurrency" , "http_timeout" ]:
53+ for param in ["batch_size" , "max_concurrency" , "http_timeout" , "server_headers" , "max_retries" , "retry_backoff_factor" ]:
5154 if param in kwargs :
5255 del kwargs [param ]
5356 if backend not in ["http-client" ] and not model_path :
@@ -175,6 +178,9 @@ def get_model(
175178 batch_size = batch_size ,
176179 max_concurrency = max_concurrency ,
177180 http_timeout = http_timeout ,
181+ server_headers = server_headers ,
182+ max_retries = max_retries ,
183+ retry_backoff_factor = retry_backoff_factor ,
178184 )
179185 elapsed = round (time .time () - start_time , 2 )
180186 logger .info (f"get { backend } predictor cost: { elapsed } s" )
0 commit comments