3030from weakref import proxy
3131
3232import torch
33+ from lightning_utilities import module_available
3334from torch .optim import Optimizer
3435
3536import lightning .pytorch as pl
7071from lightning .pytorch .utilities .compile import _maybe_unwrap_optimized , _verify_strategy_supports_compile
7172from lightning .pytorch .utilities .exceptions import MisconfigurationException
7273from lightning .pytorch .utilities .model_helpers import is_overridden
74+ from lightning .pytorch .utilities .model_registry import _is_registry , download_model_from_registry
7375from lightning .pytorch .utilities .rank_zero import rank_zero_info , rank_zero_warn
7476from lightning .pytorch .utilities .seed import isolate_rng
7577from lightning .pytorch .utilities .types import (
@@ -128,6 +130,7 @@ def __init__(
128130 sync_batchnorm : bool = False ,
129131 reload_dataloaders_every_n_epochs : int = 0 ,
130132 default_root_dir : Optional [_PATH ] = None ,
133+ model_registry : Optional [str ] = None ,
131134 ) -> None :
132135 r"""Customize every aspect of training via flags.
133136
@@ -290,6 +293,8 @@ def __init__(
290293 Default: ``os.getcwd()``.
291294 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
292295
296+ model_registry: The name of the model to be uploaded to the Lightning Model Hub.
297+
293298 Raises:
294299 TypeError:
295300 If ``gradient_clip_val`` is not an int or float.
@@ -304,6 +309,9 @@ def __init__(
304309 if default_root_dir is not None :
305310 default_root_dir = os .fspath (default_root_dir )
306311
312+ # remove version if accidentally passed
313+ self ._model_registry = model_registry .split (":" )[0 ] if model_registry else None
314+
307315 self .barebones = barebones
308316 if barebones :
309317 # opt-outs
@@ -519,7 +527,20 @@ def fit(
519527 the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
520528
521529 ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
522- keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
530+ keywords ``"last"``, ``"hpc"`` and ``"registry"``.
531+ Otherwise, if there is no checkpoint file at the path, an exception is raised.
532+
533+ - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
534+ - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
535+ - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
536+
537+ - ``'registry'``: uses the latest/default version of default model set
538+ with ``Trainer(..., model_registry="my-model")``
539+ - ``'registry:model-name'``: uses the latest/default version of this model `model-name`
540+ - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
541+ - ``'registry:version:v2'``: uses the default model set
542+ with ``Trainer(..., model_registry="my-model")`` and version 'v2'
543+
523544
524545 Raises:
525546 TypeError:
@@ -567,6 +588,8 @@ def _fit_impl(
567588 )
568589
569590 assert self .state .fn is not None
591+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
592+ download_model_from_registry (ckpt_path , self )
570593 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
571594 self .state .fn ,
572595 ckpt_path ,
@@ -596,8 +619,8 @@ def validate(
596619 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
597620 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
598621
599- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
600- If ``None`` and the model instance was passed, use the current weights.
622+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
623+ to validate. If ``None`` and the model instance was passed, use the current weights.
601624 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
602625 if a checkpoint callback is configured.
603626
@@ -675,6 +698,8 @@ def _validate_impl(
675698 self ._data_connector .attach_data (model , val_dataloaders = dataloaders , datamodule = datamodule )
676699
677700 assert self .state .fn is not None
701+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
702+ download_model_from_registry (ckpt_path , self )
678703 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
679704 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
680705 )
@@ -705,8 +730,8 @@ def test(
705730 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
706731 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
707732
708- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
709- If ``None`` and the model instance was passed, use the current weights.
733+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
734+ to test. If ``None`` and the model instance was passed, use the current weights.
710735 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
711736 if a checkpoint callback is configured.
712737
@@ -784,6 +809,8 @@ def _test_impl(
784809 self ._data_connector .attach_data (model , test_dataloaders = dataloaders , datamodule = datamodule )
785810
786811 assert self .state .fn is not None
812+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
813+ download_model_from_registry (ckpt_path , self )
787814 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
788815 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
789816 )
@@ -820,8 +847,8 @@ def predict(
820847 return_predictions: Whether to return predictions.
821848 ``True`` by default except when an accelerator that spawns processes is used (not supported).
822849
823- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
824- If ``None`` and the model instance was passed, use the current weights.
850+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
851+ to predict. If ``None`` and the model instance was passed, use the current weights.
825852 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
826853 if a checkpoint callback is configured.
827854
@@ -893,6 +920,8 @@ def _predict_impl(
893920 self ._data_connector .attach_data (model , predict_dataloaders = dataloaders , datamodule = datamodule )
894921
895922 assert self .state .fn is not None
923+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
924+ download_model_from_registry (ckpt_path , self )
896925 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
897926 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
898927 )