diff --git a/deeppavlov/core/common/chainer.py b/deeppavlov/core/common/chainer.py index b3f78d13e3..ba65f7bcb4 100644 --- a/deeppavlov/core/common/chainer.py +++ b/deeppavlov/core/common/chainer.py @@ -14,6 +14,7 @@ import pickle from itertools import islice +from pathlib import Path from logging import getLogger from types import FunctionType from typing import Union, Tuple, List, Optional, Hashable, Reversible @@ -275,10 +276,10 @@ def get_main_component(self) -> Optional[Serializable]: log.warning('Cannot get a main component for an empty chainer') return None - def save(self) -> None: + def save(self, fname: Optional[Union[str, Path]] = None) -> None: main_component = self.get_main_component() if isinstance(main_component, Serializable): - main_component.save() + main_component.save(fname) def load(self) -> None: for in_params, out_params, component in self.train_pipe: diff --git a/deeppavlov/core/models/tf_model.py b/deeppavlov/core/models/tf_model.py index 39d867165b..439e8b7a6b 100644 --- a/deeppavlov/core/models/tf_model.py +++ b/deeppavlov/core/models/tf_model.py @@ -66,16 +66,17 @@ def deserialize(self, weights: Iterable[Tuple[str, np.ndarray]]) -> None: feed_dict[assign_placeholder] = value self.sess.run(assign_ops, feed_dict=feed_dict) - def save(self, exclude_scopes: tuple = ('Optimizer',)) -> None: + def save(self, exclude_scopes: tuple = ('Optimizer',), fname: Optional[Union[str, Path]] = None) -> None: """Save model parameters to self.save_path""" if not hasattr(self, 'sess'): raise RuntimeError('Your TensorFlow model {} must' ' have sess attribute!'.format(self.__class__.__name__)) - path = str(self.save_path.resolve()) - log.info('[saving model to {}]'.format(path)) + if fname is None: + fname = str(self.save_path.resolve()) + log.info('[saving model to {}]'.format(fname)) var_list = self._get_saveable_variables(exclude_scopes) saver = tf.train.Saver(var_list) - saver.save(self.sess, path) + saver.save(self.sess, fname) def serialize(self) -> Tuple[Tuple[str, np.ndarray], ...]: tf_vars = tf.global_variables() diff --git a/deeppavlov/core/models/torch_model.py b/deeppavlov/core/models/torch_model.py index 67bfee27ab..ae9747ba83 100644 --- a/deeppavlov/core/models/torch_model.py +++ b/deeppavlov/core/models/torch_model.py @@ -16,7 +16,7 @@ from copy import deepcopy from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Union import torch from overrides import overrides @@ -127,7 +127,7 @@ def init_from_opt(self, model_func: str) -> None: raise AttributeError("Model is not defined.") @overrides - def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: + def load(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None: """Load model from `fname` (if `fname` is not given, use `self.load_path`) to `self.model` along with the optimizer `self.optimizer`, optionally `self.lr_scheduler`. If `fname` (if `fname` is not given, use `self.load_path`) does not exist, initialize model from scratch. @@ -143,15 +143,18 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: if fname is not None: self.load_path = fname + if isinstance(self.load_path, str): + self.load_path = Path(self.load_path) + model_func = getattr(self, self.opt.get("model_name"), None) if self.load_path: log.info(f"Load path {self.load_path} is given.") - if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir(): + if not self.load_path.parent.is_dir(): raise ConfigError("Provided load path is incorrect!") - weights_path = Path(self.load_path.resolve()) - weights_path = weights_path.with_suffix(f".pth.tar") + weights_path = self.load_path.resolve() + weights_path = weights_path.with_suffix(".pth.tar") if weights_path.exists(): log.info(f"Load path {weights_path} exists.") log.info(f"Initializing `{self.__class__.__name__}` from saved.") @@ -173,7 +176,7 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: self.init_from_opt(model_func) @overrides - def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: + def save(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None: """Save torch model to `fname` (if `fname` is not given, use `self.save_path`). Checkpoint includes `model_state_dict`, `optimizer_state_dict`, and `epochs_done` (number of training epochs). @@ -187,11 +190,13 @@ def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: """ if fname is None: fname = self.save_path - + else: + fname = str(self.save_path) + fname + fname = Path(fname) if not fname.parent.is_dir(): raise ConfigError("Provided save path is incorrect!") - weights_path = Path(fname).with_suffix(f".pth.tar") + weights_path = fname.with_suffix(f".pth.tar") log.info(f"Saving model to {weights_path}.") # move the model to `cpu` before saving to provide consistency torch.save({ diff --git a/deeppavlov/core/trainers/nn_trainer.py b/deeppavlov/core/trainers/nn_trainer.py index 6f6fd8b4bf..bea7977a03 100644 --- a/deeppavlov/core/trainers/nn_trainer.py +++ b/deeppavlov/core/trainers/nn_trainer.py @@ -25,6 +25,7 @@ from deeppavlov.core.data.data_learning_iterator import DataLearningIterator from deeppavlov.core.trainers.fit_trainer import FitTrainer from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder +from deeppavlov.core.models.serializable import Serializable log = getLogger(__name__) @@ -72,6 +73,8 @@ class NNTrainer(FitTrainer): log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``) max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``, ignored if negative (default is ``-1``) + save_every_n_batches: how often (in batches) to save model into f'{save_path}_{current_step}, the best model + is still saved to `save_path`, ignored if negative or zero (default is ``-1``) **kwargs: additional parameters whose names will be logged but otherwise ignored @@ -103,6 +106,7 @@ def __init__(self, chainer_config: dict, *, validate_first: bool = True, validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1, log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1, + save_every_n_batches: int = -1, **kwargs) -> None: super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets, show_examples=show_examples, tensorboard_log_dir=tensorboard_log_dir, @@ -134,6 +138,7 @@ def _improved(op): self.log_every_n_epochs = log_every_n_epochs self.log_every_n_batches = log_every_n_batches self.log_on_k_batches = log_on_k_batches if log_on_k_batches >= 0 else None + self.save_every_n_batches = save_every_n_batches self.max_epochs = epochs self.epoch = start_epoch_num @@ -150,11 +155,10 @@ def _improved(op): self.tb_train_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'train_log')) self.tb_valid_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'valid_log')) - def save(self) -> None: + def save(self, fname: Optional[Union[str, Path]] = None) -> None: if self._loaded: raise RuntimeError('Cannot save already finalized chainer') - - self._chainer.save() + self._chainer.save(fname) def _is_initial_validation(self): return self.validation_number == 0 @@ -297,6 +301,10 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None: if self.val_every_n_batches > 0 and self.train_batches_seen % self.val_every_n_batches == 0: self._validate(iterator, tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen) + + if self.save_every_n_batches > 0 and self.train_batches_seen % self.save_every_n_batches == 0: + log.info(f'Saving model at step: {self.train_batches_seen}') + self.save(fname = f'_{self.train_batches_seen}' ) self._send_event(event_name='after_batch') diff --git a/deeppavlov/models/bert/bert_sequence_tagger.py b/deeppavlov/models/bert/bert_sequence_tagger.py index 4d3d8ad75f..4e2c172a22 100644 --- a/deeppavlov/models/bert/bert_sequence_tagger.py +++ b/deeppavlov/models/bert/bert_sequence_tagger.py @@ -14,6 +14,7 @@ from logging import getLogger from typing import List, Union, Dict, Optional +from pathlib import Path import numpy as np import tensorflow as tf @@ -443,10 +444,10 @@ def __call__(self, **kwargs) -> Union[List[List[int]], List[np.ndarray]]: raise NotImplementedError("You must implement method __call__ in your derived class.") - def save(self, exclude_scopes=('Optimizer', 'EMA/BackupVariables')) -> None: + def save(self, exclude_scopes=('Optimizer', 'EMA/BackupVariables'), fname: Optional[Union[str, Path]] = None) -> None: if self.ema: self.sess.run(self.ema.switch_to_train_op) - return super().save(exclude_scopes=exclude_scopes) + return super().save(exclude_scopes=exclude_scopes, fname = fname) def load(self, exclude_scopes=('Optimizer',