@@ -397,6 +397,9 @@ class Model(Component, PredictMixin):
397397 #: The method to use for prediction (optional)
398398 predict_method : t .Optional [str ] = None
399399
400+ #: The method to transfer the model to a device
401+ model_to_device_method : t .Optional [str ] = None
402+
400403 #: Whether to batch predict (optional)
401404 batch_predict : bool = False
402405
@@ -413,6 +416,9 @@ class Model(Component, PredictMixin):
413416 future : t .Optional [Future ] = None
414417 device : str = "cpu"
415418
419+ # TODO: handle situation with multiple GPUs
420+ preferred_devices : t .Sequence [str ] = ("cuda" , "mps" , "cpu" )
421+
416422 artifacts : t .ClassVar [t .Sequence [str ]] = ['object' ]
417423
418424 type_id : t .ClassVar [str ] = 'model'
@@ -429,6 +435,21 @@ def __post_init__(self):
429435 else :
430436 self .to_call = getattr (self .object .artifact , self .predict_method )
431437
438+ self .artifact_to_method = None
439+ if self .model_to_device_method is not None :
440+ self .artifact_to_method = getattr (self , self .model_to_device_method )
441+
442+ def on_load (self , db : DB ) -> None :
443+ if self .artifact_to_method :
444+ for i , device in enumerate (self .preferred_devices ):
445+ try :
446+ self .artifact_to_method (device )
447+ self .device = device
448+ return
449+ except Exception :
450+ if i == len (self .preferred_devices ) - 1 :
451+ raise
452+
432453 @property
433454 def child_components (self ) -> t .Sequence [t .Tuple [str , str ]]:
434455 out = []
0 commit comments