Skip to content

Commit 9aee8e9

Browse files
kartik4949 authored and thejumpman2323 committed
Device fallback on model load
Former-commit-id: 5549f74
1 parent fecaaf9 commit 9aee8e9

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

superduperdb/container/model.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -397,6 +397,9 @@ class Model(Component, PredictMixin):
397397
#: The method to use for prediction (optional)
398398
predict_method: t.Optional[str] = None
399399

400+
#: The method to transfer the model to a device
401+
model_to_device_method: t.Optional[str] = None
402+
400403
#: Whether to batch predict (optional)
401404
batch_predict: bool = False
402405

@@ -413,6 +416,9 @@ class Model(Component, PredictMixin):
413416
future: t.Optional[Future] = None
414417
device: str = "cpu"
415418

419+
# TODO: handle situation with multiple GPUs
420+
preferred_devices: t.Sequence[str] = ("cuda", "mps", "cpu")
421+
416422
artifacts: t.ClassVar[t.Sequence[str]] = ['object']
417423

418424
type_id: t.ClassVar[str] = 'model'
@@ -429,6 +435,21 @@ def __post_init__(self):
429435
else:
430436
self.to_call = getattr(self.object.artifact, self.predict_method)
431437

438+
self.artifact_to_method = None
439+
if self.model_to_device_method is not None:
440+
self.artifact_to_method = getattr(self, self.model_to_device_method)
441+
442+
def on_load(self, db: DB) -> None:
    """Move the model onto the first usable entry of ``preferred_devices``.

    Nothing happens when no device-transfer method was configured
    (``self.artifact_to_method`` is falsy). Otherwise the candidate devices
    are tried in order: the first transfer that completes without raising is
    recorded in ``self.device`` and the search stops. A failure on any
    candidate but the last is swallowed so the next one can be tried; a
    failure on the last candidate propagates to the caller.

    :param db: The database the component is loaded from (not used here).
    """
    transfer = self.artifact_to_method
    if not transfer:
        return

    candidates = self.preferred_devices
    final_position = len(candidates) - 1
    for position, candidate in enumerate(candidates):
        try:
            transfer(candidate)
            self.device = candidate
        except Exception:
            # Deliberate catch-all: an unavailable device may fail in
            # backend-specific ways, so fall through to the next candidate
            # and only surface the error once every option is exhausted.
            if position == final_position:
                raise
        else:
            return
452+
432453
@property
433454
def child_components(self) -> t.Sequence[t.Tuple[str, str]]:
434455
out = []

superduperdb/ext/torch/model.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,8 @@ class TorchModel(Base, Model): # type: ignore[misc]
364364
train_forward_method: str = '__call__'
365365

366366
def __post_init__(self):
367+
self.model_to_device_method = 'move_to_device'
368+
367369
super().__post_init__()
368370

369371
self.object.serializer = 'torch'
@@ -404,6 +406,9 @@ def parameters(self):
404406
def state_dict(self):
    """Return whatever ``self.object.state_dict()`` produces."""
    wrapped = self.object
    return wrapped.state_dict()
406408

409+
def move_to_device(self, device):
    """Call ``to(device)`` on the wrapped artifact.

    The return value of ``to`` is discarded, so this relies on the artifact
    being moved in place.
    NOTE(review): presumably the artifact is a ``torch.nn.Module`` (whose
    ``to`` modifies the module in place) -- confirm against callers.

    :param device: Device identifier passed straight through to ``to``.
    """
    artifact = self.object.artifact
    artifact.to(device)
411+
407412
@contextmanager
408413
def saving(self):
409414
with super().saving():
@@ -476,6 +481,10 @@ def func(x):
476481
return out
477482

478483
def train_forward(self, X, y=None):
484+
X = X.to(self.device)
485+
if y is not None:
486+
y = y.to(self.device)
487+
479488
method = getattr(self.object.artifact, self.train_forward_method)
480489
if hasattr(self.object.artifact, 'train_forward'):
481490
if y is None:

0 commit comments

Comments
 (0)