@@ -37,8 +37,9 @@ pub struct Backend {
 }
 
 impl Backend {
-    pub fn new(
+    pub async fn new(
         model_path: PathBuf,
+        api_repo: Option<ApiRepo>,
         dtype: DType,
         model_type: ModelType,
         uds_path: String,
@@ -49,12 +50,14 @@ impl Backend {
 
         let backend = init_backend(
             model_path,
+            api_repo,
             dtype,
             model_type.clone(),
             uds_path,
             otlp_endpoint,
             otlp_service_name,
-        )?;
+        )
+        .await?;
         let padded_model = backend.is_padded();
         let max_batch_size = backend.max_batch_size();
 
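Note: `Backend::new` is now async and threads an `Option<ApiRepo>` down to `init_backend`, which downloads weights itself instead of relying on a separate `download_weights` call. A minimal caller-side sketch of building that handle with the `hf-hub` crate's tokio API (the model id, revision, and builder options are placeholders, not part of this PR):

```rust
use hf_hub::api::tokio::{ApiBuilder, ApiRepo};
use hf_hub::{Repo, RepoType};

// Sketch, assuming the `hf-hub` tokio feature this crate already uses.
async fn hub_repo(model_id: &str, revision: &str) -> ApiRepo {
    let api = ApiBuilder::new()
        .with_progress(false)
        .build()
        .expect("failed to build Hub API client");
    api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ))
}
```

Passing `None` instead of `Some(api_repo)` skips all Hub access and expects the weights to already exist under `model_path`.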
@@ -193,48 +196,102 @@ impl Backend {
 }
 
 #[allow(unused)]
-fn init_backend(
+async fn init_backend(
     model_path: PathBuf,
+    api_repo: Option<ApiRepo>,
     dtype: DType,
     model_type: ModelType,
     uds_path: String,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
 ) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
+    let mut backend_start_failed = false;
+
+    if cfg!(feature = "ort") {
+        #[cfg(feature = "ort")]
+        {
+            if let Some(api_repo) = api_repo.as_ref() {
+                let start = std::time::Instant::now();
+                download_onnx(api_repo)
+                    .await
+                    .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
+                tracing::info!("Model ONNX weights downloaded in {:?}", start.elapsed());
+            }
+
+            let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone());
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start ORT backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
+        }
+    }
+
+    if let Some(api_repo) = api_repo.as_ref() {
+        if cfg!(feature = "python") || cfg!(feature = "candle") {
+            let start = std::time::Instant::now();
+            if download_safetensors(api_repo).await.is_err() {
+                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
+                tracing::info!("Downloading `pytorch_model.bin`");
+                api_repo
+                    .get("pytorch_model.bin")
+                    .await
+                    .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
+            }
+
+            tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+        }
+    }
+
     if cfg!(feature = "candle") {
         #[cfg(feature = "candle")]
-        return Ok(Box::new(CandleBackend::new(
-            model_path,
-            dtype.to_string(),
-            model_type,
-        )?));
-    } else if cfg!(feature = "python") {
+        {
+            let backend = CandleBackend::new(&model_path, dtype.to_string(), model_type.clone());
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start Candle backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
+        }
+    }
+
+    if cfg!(feature = "python") {
         #[cfg(feature = "python")]
         {
-            return Ok(Box::new(
-                std::thread::spawn(move || {
-                    PythonBackend::new(
-                        model_path.to_str().unwrap().to_string(),
-                        dtype.to_string(),
-                        model_type,
-                        uds_path,
-                        otlp_endpoint,
-                        otlp_service_name,
-                    )
-                })
-                .join()
-                .expect("Python Backend management thread failed")?,
-            ));
+            let backend = std::thread::spawn(move || {
+                PythonBackend::new(
+                    model_path.to_str().unwrap().to_string(),
+                    dtype.to_string(),
+                    model_type,
+                    uds_path,
+                    otlp_endpoint,
+                    otlp_service_name,
+                )
+            })
+            .join()
+            .expect("Python Backend management thread failed");
+
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start Python backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
         }
-    } else if cfg!(feature = "ort") {
-        #[cfg(feature = "ort")]
-        return Ok(Box::new(OrtBackend::new(
-            model_path,
-            dtype.to_string(),
-            model_type,
-        )?));
     }
-    Err(BackendError::NoBackend)
+
+    if backend_start_failed {
+        Err(BackendError::Start(
+            "Could not start a suitable backend".to_string(),
+        ))
+    } else {
+        Err(BackendError::NoBackend)
+    }
 }
 
 #[derive(Debug)]
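The `if cfg!(feature = ...)` wrapping an inner `#[cfg(feature = ...)]` block is deliberate: `cfg!` expands to a plain boolean, so the branch participates in normal control flow, while the attribute strips the backend-specific code and its optional dependencies from compilation entirely. The `backend_start_failed` flag then distinguishes "a compiled-in backend failed to start" (`BackendError::Start`) from "no backend feature was enabled at all" (`BackendError::NoBackend`). A self-contained sketch of the same pattern, using a hypothetical `fast` feature and `fast_number` helper (not from this codebase):

```rust
// Only compiled when the optional feature is on, like the backend constructors.
#[cfg(feature = "fast")]
fn fast_number() -> Result<u32, &'static str> {
    Ok(42)
}

// `#[allow(unused)]` mirrors the diff: with every feature off, the mutable
// flag is never written, and the attribute silences that warning.
#[allow(unused)]
fn pick_number() -> Result<u32, &'static str> {
    let mut start_failed = false;

    if cfg!(feature = "fast") {
        #[cfg(feature = "fast")]
        {
            match fast_number() {
                Ok(n) => return Ok(n),
                // Fall through so a later option could be tried instead.
                Err(_) => start_failed = true,
            }
        }
    }

    if start_failed {
        Err("an enabled option failed to start")
    } else {
        Err("no option was enabled at build time")
    }
}
```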
@@ -298,31 +355,6 @@ enum BackendCommand {
     ),
 }
 
-pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
-    let model_files = if cfg!(feature = "python") || cfg!(feature = "candle") {
-        match download_safetensors(api).await {
-            Ok(p) => p,
-            Err(_) => {
-                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
-                tracing::info!("Downloading `pytorch_model.bin`");
-                let p = api.get("pytorch_model.bin").await?;
-                vec![p]
-            }
-        }
-    } else if cfg!(feature = "ort") {
-        match download_onnx(api).await {
-            Ok(p) => p,
-            Err(err) => {
-                panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}");
-            }
-        }
-    } else {
-        unreachable!()
-    };
-
-    Ok(model_files)
-}
-
 async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     // Single file
     tracing::info!("Downloading `model.safetensors`");
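Only the head and tail of `download_safetensors` appear in this diff. Assuming the elided middle follows the standard Hub convention for sharded checkpoints, it would look roughly like the sketch below (hypothetical `download_sharded` helper and an assumed `serde_json` dependency, not the PR's actual body): fetch `model.safetensors.index.json`, dedupe the shard names in its `weight_map`, and download each shard.

```rust
use std::collections::HashSet;
use std::path::PathBuf;

use hf_hub::api::tokio::{ApiError, ApiRepo};

// Sketch of the sharded-checkpoint fallback used when `model.safetensors`
// is missing from the repository.
async fn download_sharded(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
    let index_file = api.get("model.safetensors.index.json").await?;
    let index: serde_json::Value =
        serde_json::from_str(&std::fs::read_to_string(index_file).expect("index readable"))
            .expect("index is valid JSON");

    // `weight_map` maps tensor name -> shard file name; dedupe the shards.
    let shards: HashSet<String> = index["weight_map"]
        .as_object()
        .expect("weight_map object")
        .values()
        .filter_map(|v| v.as_str().map(str::to_string))
        .collect();

    let mut files = Vec::new();
    for shard in shards {
        files.push(api.get(&shard).await?);
    }
    Ok(files)
}
```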
@@ -362,6 +394,7 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     Ok(safetensors_files)
 }
 
+#[cfg(feature = "ort")]
 async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     let mut model_files: Vec<PathBuf> = Vec::new();
 
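`download_onnx` is likewise truncated here. Based on the file names cited in the removed `download_weights` (`model.onnx` and `model.onnx_data`), a hedged sketch of its plausible shape (hypothetical `download_onnx_sketch` name, not the PR's actual body):

```rust
use std::path::PathBuf;

use hf_hub::api::tokio::{ApiError, ApiRepo};

// Sketch: fetch the ONNX graph, then opportunistically fetch the optional
// external-data file that large exports ship alongside it.
async fn download_onnx_sketch(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
    let mut model_files: Vec<PathBuf> = Vec::new();

    tracing::info!("Downloading `model.onnx`");
    model_files.push(api.get("model.onnx").await?);

    // `model.onnx_data` is optional; ignore a miss rather than failing.
    if let Ok(p) = api.get("model.onnx_data").await {
        model_files.push(p);
    }

    Ok(model_files)
}
```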