Source code for inferno.trainers.callbacks.essentials

import numpy as np
import os
import h5py as h5
from ...utils import torch_utils as tu
from ...utils.train_utils import Frequency
from ...utils.exceptions import assert_, FrequencyValueError, NotUnwrappableError
from ...utils import python_utils as pyu
from .base import Callback


class NaNDetector(Callback):
    """Raises a RuntimeError as soon as the training loss stops being finite."""
    def end_of_training_iteration(self, **_):
        training_loss = self.trainer.get_state('training_loss')
        # Extract scalar
        if tu.is_tensor(training_loss):
            training_loss = training_loss.float()[0]
        if not np.isfinite(training_loss):
            raise RuntimeError("Loss is not finite (loss={})!".format(training_loss))
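
# Usage sketch (illustrative, not part of the original module): assuming an inferno
# ``Trainer`` instance that exposes ``register_callback``, attaching the callback makes
# the run fail fast instead of silently diverging once the loss becomes NaN or infinite:
#
#     from inferno.trainers.basic import Trainer
#     trainer = Trainer(model)
#     trainer.register_callback(NaNDetector())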


class PersistentSave(Callback):
    """Saves checkpoints under epoch- and iteration-stamped filenames instead of
    overwriting a single checkpoint file."""
    def __init__(self, template='checkpoint.pytorch.epoch{epoch_count}.iteration{iteration_count}'):
        super(PersistentSave, self).__init__()
        self.template = template

    def begin_of_save(self, **kwargs):
        self._orig_checkpoint_filename = self.trainer._checkpoint_filename
        self.trainer._checkpoint_filename = self.template.format(**kwargs)

    def end_of_save(self, save_to_directory, **_):
        orig_checkpoint_path = os.path.join(save_to_directory, self._orig_checkpoint_filename)
        if os.path.lexists(orig_checkpoint_path):
            os.remove(orig_checkpoint_path)
        # Point the original checkpoint name at the persistent, stamped file
        os.symlink(self.trainer._checkpoint_filename, orig_checkpoint_path)
        self.trainer._checkpoint_filename = self._orig_checkpoint_filename
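
# Usage sketch (illustrative, not part of the original module): ``begin_of_save`` receives
# the trainer's save kwargs (assumed here to include ``epoch_count`` and ``iteration_count``),
# so every checkpoint is written under a stamped filename while the plain checkpoint name
# is kept as a symlink to the most recent one:
#
#     trainer.register_callback(
#         PersistentSave(template='checkpoint.pytorch.epoch{epoch_count}.iteration{iteration_count}'))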


class DumpHDF5Every(Callback):
    """Dumps intermediate training states to a HDF5 file."""
    def __init__(self, frequency, to_directory,
                 filename_template='dump.{mode}.epoch{epoch_count}.iteration{iteration_count}.h5',
                 force_dump=False, dump_after_every_validation_run=False):
        super(DumpHDF5Every, self).__init__()
        # Privates
        self._dump_every = None
        self._trainer_states_to_be_dumped_while_training = \
            {'training_inputs', 'training_target', 'training_prediction'}
        self._trainer_states_to_be_dumped_while_validating = \
            {'validation_inputs', 'validation_target', 'validation_prediction'}
        self._dump_cache = {}
        # Publics
        self.dump_every = frequency
        self.dump_directory = to_directory
        self.dump_filename_template = filename_template
        self.force_dump = force_dump
        self.dump_after_every_validation_run = dump_after_every_validation_run

    @property
    def dump_every(self):
        return self._dump_every

    @dump_every.setter
    def dump_every(self, value):
        self._dump_every = Frequency.build_from(value)
        assert_(self._dump_every.is_consistent,
                "Dump frequency is not consistent.",
                FrequencyValueError)

    @property
    def dump_now(self):
        return self.dump_every.match(iteration_count=self.trainer.iteration_count,
                                     epoch_count=self.trainer.epoch_count,
                                     persistent=True, match_zero=True)

    def add_to_dump_cache(self, key, value):
        if pyu.is_listlike(value):
            # Flatten list-like states to one cache entry (and dataset) per element
            for value_num, _value in enumerate(value):
                self.add_to_dump_cache("{}_{}".format(key, value_num), _value)
        else:
            self._dump_cache.update({key: value})

    def clear_dump_cache(self):
        self._dump_cache.clear()

    def dump_state(self, key, dump_while='training'):
        # Validate arguments
        keyword_mapping = {'train': 'training',
                           'training': 'training',
                           'validation': 'validating',
                           'validating': 'validating'}
        dump_while = keyword_mapping.get(dump_while)
        assert_(dump_while is not None,
                "The keyword dump_while must be one of: {}."
                .format(set(keyword_mapping.keys())),
                ValueError)
        assert_(isinstance(key, str),
                "State key must be a string, got {} instead.".format(type(key).__name__),
                TypeError)
        # Add to set of observed states
        if dump_while == 'training':
            self._trainer_states_to_be_dumped_while_training.add(key)
        elif dump_while == 'validating':
            self._trainer_states_to_be_dumped_while_validating.add(key)
        else:
            raise NotImplementedError
        return self

    def dump_states(self, keys, dump_while='training'):
        for key in keys:
            self.dump_state(key, dump_while=dump_while)
        return self

    def get_file_path(self, mode):
        # Make sure the dump directory exists
        if not os.path.exists(self.dump_directory):
            os.mkdir(self.dump_directory)
        else:
            assert_(os.path.isdir(self.dump_directory),
                    "Dump directory {} is a file.".format(self.dump_directory),
                    FileExistsError)
        filename = self.dump_filename_template.format(epoch_count=self.trainer.epoch_count,
                                                      iteration_count=self.trainer.iteration_count,
                                                      mode=mode)
        return os.path.join(self.dump_directory, filename)

    def dump(self, mode):
        with h5.File(name=self.get_file_path(mode), mode='w') as h5_file:
            for key, to_dump in self._dump_cache.items():
                if to_dump is None:
                    continue
                try:
                    to_dump = tu.unwrap(to_dump, as_numpy=True)
                except NotUnwrappableError:
                    # Can't unwrap to_dump, but let's not throw a tantrum if we're not required to
                    if not self.force_dump:
                        continue
                    else:
                        raise
                # Do the dumpin'
                h5_file.create_dataset(name=key, data=to_dump)

    def end_of_training_iteration(self, **_):
        if self.dump_now:
            # To be double sure
            self.clear_dump_cache()
            # Get objects to dump
            for state_name in self._trainer_states_to_be_dumped_while_training:
                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))
            # Dump
            self.dump(mode='training')
            # Clear cache
            self.clear_dump_cache()

    def end_of_validation_run(self, **_):
        if self.dump_after_every_validation_run:
            # To be double sure
            self.clear_dump_cache()
            # Get objects to dump
            for state_name in self._trainer_states_to_be_dumped_while_validating:
                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))
            # Dump
            self.dump(mode='validation')
            # Clear cache
            self.clear_dump_cache()
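
# Usage sketch (illustrative, not part of the original module): the frequency argument is
# passed through ``Frequency.build_from``, which accepts specs along the lines of
# ``'500 iterations'`` or ``'1 epochs'``; treat the exact spec string below as an assumption.
# Additional trainer states can be registered for dumping via ``dump_state``/``dump_states``:
#
#     dumper = DumpHDF5Every('500 iterations', to_directory='dumps',
#                            dump_after_every_validation_run=True)
#     dumper.dump_states(['training_loss'], dump_while='training')
#     trainer.register_callback(dumper)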


class SaveAtBestValidationScore(Callback):
    """
    Triggers a save at the best EMA (exponential moving average) validation score.
    The basic `Trainer` has built-in support for saving at the best validation score,
    but this callback might eventually replace that functionality.
    """
    def __init__(self, smoothness=0, verbose=False):
        super(SaveAtBestValidationScore, self).__init__()
        # Privates
        self._ema_validation_score = None
        self._best_ema_validation_score = None
        # Publics
        self.smoothness = smoothness
        self.verbose = verbose

    def end_of_validation_run(self, **_):
        # Get score (i.e. validation error if available, else validation loss)
        current_validation_score = self.trainer.get_state('validation_error_averaged')
        current_validation_score = self.trainer.get_state('validation_loss_averaged') \
            if current_validation_score is None else current_validation_score
        # Maintain EMA
        if self._ema_validation_score is None:
            self._ema_validation_score = current_validation_score
            self._best_ema_validation_score = current_validation_score
        else:
            self._ema_validation_score = self.smoothness * self._ema_validation_score + \
                (1 - self.smoothness) * current_validation_score
        # This overrides the default behaviour, but reduces to it if smoothness == 0
        self.trainer._is_iteration_with_best_validation_score = \
            self._ema_validation_score < self._best_ema_validation_score
        # Trigger a save
        if self.trainer._is_iteration_with_best_validation_score:
            if self.verbose:
                self.trainer.console.info("Current smoothed validation score {} is better "
                                          "than the best smoothed validation score {}."
                                          .format(self._ema_validation_score,
                                                  self._best_ema_validation_score))
            self._best_ema_validation_score = self._ema_validation_score
            self.trainer.save_now = True
        else:
            if self.verbose:
                self.trainer.console.info("Current smoothed validation score {} is not better "
                                          "than the best smoothed validation score {}."
                                          .format(self._ema_validation_score,
                                                  self._best_ema_validation_score))
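
# Usage sketch (illustrative, not part of the original module): with ``smoothness=0`` the
# EMA equals the raw validation score, so the callback reduces to the trainer's built-in
# save-at-best behaviour; a value such as 0.9 smooths noisy validation scores before
# deciding whether to save:
#
#     trainer.register_callback(SaveAtBestValidationScore(smoothness=0.9, verbose=True))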


class ParameterEMA(Callback):
    """Maintain a moving average of network parameters."""
    def __init__(self, momentum):
        """
        Parameters
        ----------
        momentum : float
            Momentum for the moving average. The following holds:
            `new_moving_average = momentum * old_moving_average + (1 - momentum) * value`
        """
        super(ParameterEMA, self).__init__()
        # Privates
        self._parameters = None
        # Publics
        self.momentum = momentum

    def maintain(self):
        if self._parameters is None:
            # Initialize the EMA buffers as zero tensors of the same shape as the parameters
            self._parameters = [p.data.new(p.data.size()).zero_()
                                for p in self.trainer.model.parameters()]
        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
            p_ema.mul_(self.momentum).add_(p_model.data.mul(1. - self.momentum))

    def apply(self):
        assert_(self._parameters is not None,
                "Can't apply parameter EMAs: none available yet.",
                ValueError)
        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
            p_model.data.copy_(p_ema)

    def end_of_training_iteration(self, **_):
        self.maintain()
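
# Usage sketch (illustrative, not part of the original module): the EMA buffers are updated
# after every training iteration, and ``apply`` copies the averaged values into the model
# in place, overwriting the trained weights; save a checkpoint first if you need them back:
#
#     ema = ParameterEMA(momentum=0.999)
#     trainer.register_callback(ema)
#     # ... train ...
#     ema.apply()  # replace model parameters with their moving averages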