diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index 88dcba29..8d3b37cb 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -103,11 +103,17 @@ def setup_and_train( training_modules: TrainingModules, database, setup_params: SetupParams, + store_all_better=False, + store_best=True, + store_every=0 ): """ :param: training_modules: see :func:`setup_training` :param: database: see :func:`train_model` :param: setup_params: see :func:`setup_training` + :param: store_all_better: Save the state dict for each model doing better than a previous one + :param: store_best: Save a checkpoint for the best model + :param: store_every: Save a checkpoint for every certain epochs :return: See :func:`train_model` Shortcut for setup_training followed by train_model. @@ -134,6 +140,9 @@ def setup_and_train( metric_tracker=metric_tracker, callbacks=None, batch_callbacks=None, + store_all_better=store_all_better, + store_best=store_best, + store_every=store_every ) @@ -212,6 +221,7 @@ def train_model( batch_callbacks, store_all_better=False, store_best=True, + store_every=0, store_structure_file=True, store_metrics=True, quiet=False, @@ -228,6 +238,7 @@ def train_model( :param batch_callbacks: callbacks to perform after every batch :param store_best: Save a checkpoint for the best model :param store_all_better: Save the state dict for each model doing better than a previous one + :param store_every: Save a checkpoint for every certain epochs :param store_structure_file: Save the structure file for this experiment :param store_metrics: Save the metric tracker for this experiment. :param quiet: If True, disable printing during training (still prints testing results). @@ -286,6 +297,7 @@ def train_model( batch_callbacks=batch_callbacks, store_best=store_best, store_all_better=store_all_better, + store_every=store_every, quiet=quiet, ) @@ -364,6 +376,7 @@ def training_loop( batch_callbacks, store_all_better, store_best, + store_every, quiet, ): """ @@ -377,6 +390,7 @@ def training_loop( :param batch_callbacks: list of callbacks for each batch :param store_best: Save a checkpoint for the best model :param store_all_better: Save the state dict for each model doing better than a previous one + :param store_every: Save a checkpoint for every certain epochs :param quiet: whether to print information. Setting quiet to true won't prevent progress bars. :return: metrics -- the state of the experiment after training @@ -506,6 +520,17 @@ def training_loop( # Write the checkpoint with open("best_checkpoint.pt", "wb") as pfile: torch.save(state, pfile) + + if store_every and epoch != 0 and (epoch % store_every) == 0: + # Save a copy every "store_every" epoch + with open(f"model_epoch_{epoch}.pt", "wb") as pfile: + torch.save(model.state_dict(), pfile) + + state = serialization.create_state(model, controller, metric_tracker) + + # Write the checkpoint + with open(f"checkpoint_epoch_{epoch}.pt", "wb") as pfile: + torch.save(state, pfile) epoch += 1