Source code for inferno.trainers.basic

import dill
from datetime import datetime
from inspect import signature
import os
import shutil

import torch
from numpy import inf
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn.parallel.data_parallel import data_parallel
from .callbacks.logging.base import Logger
from .callbacks.logging import get_logger

from ..utils import train_utils as tu
from ..utils import python_utils as pyu
from ..utils import torch_utils as thu
from ..extensions import metrics
from ..extensions import optimizers
from ..extensions import criteria
from .callbacks import CallbackEngine
from .callbacks import Console
from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError


class Trainer(object):
    """A basic trainer.

    Given a torch model, this class encapsulates the training and validation loops,
    checkpoint creation, logging, CPU <-> GPU transfers and managing data-loaders.

    In addition, this class interacts with the callback engine (found at
    `inferno.trainers.callbacks.base.CallbackEngine`), which manages callbacks at
    certain preset events.

    Notes
    -----
    Logging is implemented as a special callback, in the sense that it's jointly
    managed by this class and the callback engine. This is primarily because general
    callbacks are not intended to be serializable, but not being able to serialize
    the logger is a nuisance.
    """
    def __init__(self, model=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Torch model to bind to.
        """
        # Privates
        # Core
        self._model = None
        self._optimizer = None
        self._criterion = None
        # Metric evaluation
        self._metric = None
        self._evaluate_metric_every = None
        self._metric_evaluation_externally_triggered = False
        self._last_metric_evaluated_at_epoch = 0
        # Logging
        self._logger = None
        self._last_logged = {}
        self._log_directory = {}
        # Data logistics
        self._loaders = {}
        self._loader_iters = {}
        self._loader_specs = {}
        # Iteration and epoch book-keeping
        self._iteration_count = 0
        self._epoch_count = 0
        self._batch_count = 0
        # GPU and dtype business
        self._use_cuda = False
        self._dtype = 'float'
        self._devices = None
        self._base_device_ordinal = None
        # Validation
        self._save_at_best_validation_score = False
        self._best_validation_score = None
        self._is_iteration_with_best_validation_score = False
        self._validate_every = None
        self._num_validation_iterations = None
        self._target_batch_dim = 0
        # We should exclude the zero-th epoch from validation
        self._last_validated_at_epoch = 0
        self._last_validated_at_iteration = 0
        # This is to allow a callback to trigger a validation by setting
        # trainer.validate_now = True
        self._validation_externally_triggered = False
        # Checkpointing
        self._save_every = None
        self._save_to_directory = None
        # Defaults for file names
        self._checkpoint_filename = 'checkpoint.pytorch'
        self._best_checkpoint_filename = 'best_checkpoint.pytorch'
        # Nothing to save at epoch 0
        self._last_saved_at_epoch = 0
        # This is to allow a callback to trigger a save by setting trainer.save_now = True
        self._save_externally_triggered = False
        # Stopping conditions
        self._max_num_iterations = None
        self._max_num_epochs = None
        # Callbacks and states
        self._callback_engine = CallbackEngine().bind_trainer(self)
        self._state = {}
        # Print console
        self._console = Console()
        # Public
        if model is not None:
            self.model = model

    @property
    def console(self):
        """Gets the current console."""
        return self._console

    @property
    def callbacks(self):
        """Gets the callback engine."""
        return self._callback_engine
[docs] def register_callback(self, callback, trigger='auto', **callback_kwargs): """ Registers a callback with the internal callback engine. Parameters ---------- callback : type or callable Callback to register. trigger : str Specify the event that triggers the callback. Leave at 'auto' to have the callback-engine figure out the triggers. See `inferno.training.callbacks.base.CallbackEngine` documentation for more on this. callback_kwargs : dict If `callback` is a type, initialize an instance with these keywords to the __init__ method. Returns ------- Trainer self. """ if isinstance(callback, type): callback = callback(**callback_kwargs) self._callback_engine.register_callback(callback, trigger=trigger) return self
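    # Usage sketch (illustrative, not from the original source). `MyCallback` is a
    # hypothetical user-defined callback class; with trigger='auto' the engine inspects
    # it to figure out which events it should be triggered on:
    #
    #     trainer = Trainer(model)
    #     trainer.register_callback(MyCallback, some_kwarg=42)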
    @property
    def model(self):
        """Gets the model."""
        assert_(self._model is not None, "Model is not defined yet.", NotSetError)
        return self._model

    @model.setter
    def model(self, value):
        self.bind_model(value)

    def bind_model(self, model):
        """
        Binds a model to the trainer. Equivalent to setting model.

        Parameters
        ----------
        model : torch.nn.Module
            Model to bind.

        Returns
        -------
        Trainer
            self.
        """
        assert_(isinstance(model, torch.nn.Module),
                "Model must be a torch.nn.Module.",
                NotTorchModuleError)
        self._model = model
        # Transfer model to GPU if required
        if self._use_cuda:
            self._model.cuda()
        return self
@property def model_is_defined(self): return self._model is not None @property def optimizer(self): """Gets the optimizer.""" assert_(self._optimizer is not None, "Optimizer is not set yet.", NotSetError) return self._optimizer @optimizer.setter def optimizer(self, value): if isinstance(value, str) or callable(value): self.build_optimizer(value) elif isinstance(value, dict): self.build_optimizer(**value) else: raise NotImplementedError @property def optimizer_is_defined(self): return self._optimizer is not None
[docs] def build_optimizer(self, method, param_groups=None, **kwargs): """ Builds the optimizer for training. Parameters ---------- method : str or callable or torch.optim.Optimizer Name of the optimizer when str, handle to the optimizer class when callable, or a torch.optim.Optimizer instance. If a name is provided, this method looks for the optimizer in `torch.optim` module first and in inferno.extensions.optimizers second. param_groups : list of dict Specifies the parameter group. Defaults to model.parameters() if None. kwargs : dict Keyword arguments to the optimizer. Returns ------- Trainer self. Raises ------ AssertionError if optimizer is not found NotImplementedError if method is not str or callable. """ if isinstance(method, str): optimizer_class = getattr(torch.optim, method, None) if optimizer_class is None: # Look for optimizer in extensions optimizer_class = getattr(optimizers, method, None) assert optimizer_class is not None, "Optimizer {} not found.".format(method) elif callable(method) and isinstance(method, type): optimizer_class = method elif isinstance(method, torch.optim.Optimizer): self._optimizer = method return self else: raise NotImplementedError param_groups = self.model.parameters() if param_groups is None else param_groups self._optimizer = optimizer_class(param_groups, **kwargs) return self
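    # Usage sketch (illustrative, not from the original source): the optimizer can be
    # given by name (resolved in torch.optim first, then inferno.extensions.optimizers),
    # by class, or as a ready-made torch.optim.Optimizer instance. The keyword arguments
    # shown are standard torch.optim.Adam arguments:
    #
    #     trainer.build_optimizer('Adam', lr=1e-3, weight_decay=1e-5)
    #     # or, equivalently, via the property setter:
    #     trainer.optimizer = {'method': 'Adam', 'lr': 1e-3}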
@property def criterion(self): """Gets the loss criterion.""" assert_(self._criterion is not None, "Criterion is not set yet.", NotSetError) return self._criterion @criterion.setter def criterion(self, value): if isinstance(value, str) or callable(value): self.build_criterion(value) elif isinstance(value, dict): self.build_criterion(**value) else: raise NotImplementedError
    def build_criterion(self, method, **kwargs):
        """
        Builds the loss criterion for training.

        Parameters
        ----------
        method : str or callable or torch.nn.Module
            Name of the criterion when str, criterion class when callable, or a
            torch.nn.Module instance. If a name is provided, this method looks for
            the criterion in `torch.nn` first and in `inferno.extensions.criteria`
            second.
        kwargs : dict
            Keyword arguments to the criterion class' constructor, if applicable.

        Returns
        -------
        Trainer
            self.

        Raises
        ------
        AssertionError
            if criterion is not found.
        NotImplementedError
            if method is neither a str nor a callable.
        """
        if isinstance(method, str):
            # Look for criteria in torch
            criterion_class = getattr(torch.nn, method, None)
            if criterion_class is None:
                # Look for it in extensions
                criterion_class = getattr(criteria, method, None)
            assert criterion_class is not None, "Criterion {} not found.".format(method)
        elif callable(method) and isinstance(method, type):
            criterion_class = method
        elif isinstance(method, torch.nn.Module):
            self._criterion = method
            return self
        else:
            raise NotImplementedError
        self._criterion = criterion_class(**kwargs)
        # Transfer criterion to GPU if required. This is necessary for e.g. weighted loss,
        # where the weight is registered as a buffer.
        # The criterion is to be cuda'ed only if the model is on CUDA (self._use_cuda) and
        # the base_device is not CPU (ordinal -1).
        if hasattr(self, '_base_device_ordinal'):
            # This is to not break old checkpoints
            base_device_ordinal = self._base_device_ordinal
        else:
            base_device_ordinal = None
        if self._use_cuda and base_device_ordinal != -1:
            self._criterion.cuda()
        return self
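    # Usage sketch (illustrative, not from the original source): the criterion can be
    # given by name (looked up in torch.nn, then inferno.extensions.criteria), by class,
    # or as an instance. Constructor kwargs such as `weight` are forwarded, which is why
    # the criterion may need to follow the model to the GPU. `class_weights` below is a
    # hypothetical tensor defined elsewhere:
    #
    #     trainer.build_criterion('CrossEntropyLoss')
    #     trainer.build_criterion(torch.nn.NLLLoss, weight=class_weights)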
@property def criterion_is_defined(self): return self._criterion is not None @property def metric(self): """Gets the evaluation metric.""" assert_(self._metric is not None, "Metric is not set yet.", NotSetError) return self._metric @metric.setter def metric(self, value): if callable(value) or isinstance(value, str): self.build_metric(value) else: raise NotImplementedError @property def evaluating_metric_every(self): return self._evaluate_metric_every
[docs] def evaluate_metric_every(self, frequency): """ Set frequency of metric evaluation __during training__ (and not during validation). Parameters ---------- frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int Metric evaluation frequency. If str, it could be (say) '10 iterations' or '1 epoch'. If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). If int (say 10), it's interpreted as (10, 'iterations'). Returns ------- Trainer self """ self._evaluate_metric_every = tu.Frequency.build_from(frequency, priority='iterations') assert self._evaluate_metric_every.is_consistent return self
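    # Usage sketch (illustrative, not from the original source): per the docstring above,
    # all of the following request a metric evaluation every 100 training iterations:
    #
    #     trainer.evaluate_metric_every(100)
    #     trainer.evaluate_metric_every((100, 'iterations'))
    #     trainer.evaluate_metric_every('100 iterations')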
@property def evaluate_metric_now(self): if self._metric_evaluation_externally_triggered: # Reset trigger self._metric_evaluation_externally_triggered = False return True elif self._evaluate_metric_every is None: # By default, evaluate metric every time return True elif self._evaluate_metric_every is not None and self._evaluate_metric_every.by_epoch: # Don't evaluate if we've done so already this epoch if self._last_metric_evaluated_at_epoch == self._epoch_count: return False else: # If we haven't evaluated this epoch, check if we should return self._evaluate_metric_every.match(epoch_count=self._epoch_count) else: # This is reached when evaluate_metric_every is defined and matching by # iteration count return self._evaluate_metric_every.match(iteration_count=self._iteration_count) @evaluate_metric_now.setter def evaluate_metric_now(self, value): self._metric_evaluation_externally_triggered = bool(value)
[docs] def build_metric(self, method, **kwargs): """ Builds the metric for evaluation. Parameters ---------- method : callable or str Name of the metric when string, metric class or a callable object when callable. If a name is provided, this method looks for the metric in `inferno.extensions.metrics`. kwargs : dict Keyword arguments to the metric class' constructor, if applicable. Returns ------- Trainer self. Raises ------ AssertionError: if the metric is not found. """ if callable(method): if isinstance(method, type): self._metric = method(**kwargs) else: self._metric = method elif isinstance(method, str): assert hasattr(metrics, method), \ "Could not find the metric '{}'.".format(method) self._metric = getattr(metrics, method)() else: raise NotImplementedError return self
@property def metric_is_defined(self): """Checks if the metric is defined.""" return self._metric is not None
    def eval_mode(self):
        """Set model, criterion and metric to eval mode."""
        self.model.eval()
        if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module):
            self.criterion.eval()
        if self.metric_is_defined and isinstance(self.metric, torch.nn.Module):
            self.metric.eval()
        return self

    def train_mode(self):
        """Set model, criterion and metric to train mode."""
        self.model.train()
        if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module):
            self.criterion.train()
        if self.metric_is_defined and isinstance(self.metric, torch.nn.Module):
            self.metric.train()
        return self
@property def train_loader(self): assert self._loaders.get('train') is not None return self._loaders.get('train') @train_loader.setter def train_loader(self, value): assert isinstance(value, DataLoader) self._loaders.update({'train': value}) @property def validate_loader(self): assert self._loaders.get('validate') is not None return self._loaders.get('validate') @validate_loader.setter def validate_loader(self, value): assert isinstance(value, DataLoader) self._loaders.update({'validate': value}) @property def logger(self): """Gets the logger.""" return self._logger @logger.setter def logger(self, value): if isinstance(value, dict): self.build_logger(**value) else: self.build_logger(logger=value) @property def log_directory(self): """Gets the log directory.""" return self._log_directory @log_directory.setter def log_directory(self, value): """Sets the log directory,""" if value is not None: self.set_log_directory(value) @property def saving_every(self): """Gets the frequency at which checkpoints are made.""" return self._save_every
    def save_at_best_validation_score(self, yes=True):
        """Sets whether to save when the validation score is the best seen."""
        self._save_at_best_validation_score = yes
        return self
    @property
    def save_now(self):
        if self._save_externally_triggered:
            # Reset trigger
            self._save_externally_triggered = False
            # Save if externally triggered
            return True
        elif self._save_at_best_validation_score and \
                self._is_iteration_with_best_validation_score:
            return True
        else:
            # Check if we're saving by epoch
            if self._save_every is not None and self._save_every.by_epoch:
                # Don't save if we've already saved once this epoch
                if self._epoch_count == self._last_saved_at_epoch:
                    return False
                else:
                    # If we haven't saved this epoch, check if we should
                    return self._save_every.match(epoch_count=self._epoch_count)
            else:
                # We're saving by iterations
                return self._save_every is not None and \
                    self._save_every.match(iteration_count=self._iteration_count)

    @save_now.setter
    def save_now(self, value):
        """Can be set to True to trigger a checkpoint creation."""
        self._save_externally_triggered = bool(value)
    def save_every(self, frequency, to_directory=None,
                   checkpoint_filename=None,
                   best_checkpoint_filename=None):
        """
        Set checkpoint creation frequency.

        Parameters
        ----------
        frequency : inferno.utils.train_utils.Frequency or tuple or str
            Checkpoint creation frequency. Examples: '100 iterations' or '1 epoch'.
        to_directory : str
            Directory where the checkpoints are to be created.
        checkpoint_filename : str
            Name of the checkpoint file.
        best_checkpoint_filename : str
            Name of the best checkpoint file.

        Returns
        -------
        Trainer
            self.
        """
        self._save_every = tu.Frequency.build_from(frequency, priority='iterations')
        assert self._save_every.is_consistent
        self.save_to_directory(to_directory, checkpoint_filename, best_checkpoint_filename)
        return self
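    # Usage sketch (illustrative, not from the original source): create a checkpoint once
    # per epoch in './checkpoints'. The directory is created on demand by save_to_directory,
    # and the file names default to 'checkpoint.pytorch' and 'best_checkpoint.pytorch':
    #
    #     trainer.save_every('1 epoch', to_directory='./checkpoints')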
    @property
    def save_directory(self):
        return self._save_to_directory

    def save_to_directory(self, to_directory=None, checkpoint_filename=None,
                          best_checkpoint_filename=None):
        if to_directory is not None:
            assert_(isinstance(to_directory, str), exception_type=TypeError)
            if not os.path.exists(to_directory):
                os.makedirs(to_directory)
            else:
                assert os.path.isdir(to_directory)
            self._save_to_directory = to_directory
        if checkpoint_filename is not None:
            assert_(isinstance(checkpoint_filename, str), exception_type=TypeError)
            self._checkpoint_filename = checkpoint_filename
        if best_checkpoint_filename is not None:
            assert_(isinstance(best_checkpoint_filename, str), exception_type=TypeError)
            self._best_checkpoint_filename = best_checkpoint_filename
        return self
@property def validating_every(self): return self._validate_every @property def validate_now(self): if self._validation_externally_triggered: # Reset trigger self._validation_externally_triggered = False return True elif self._validate_every is not None and self._validate_every.by_epoch: # Don't validate if we've done so already this epoch if self._last_validated_at_epoch == self._epoch_count: return False else: # If we haven't validated this epoch, check if we should return self._validate_every.match(epoch_count=self._epoch_count, match_zero=False) else: # Don't validate if we've done once already this iteration if self._last_validated_at_iteration == self._iteration_count: return False else: # If we haven't validated this iteration, check if we should. The `match_zero` is # redundant, but we'll leave it on anyway. return self._validate_every is not None and \ self._validate_every.match(iteration_count=self._iteration_count, match_zero=False) @validate_now.setter def validate_now(self, value): self._validation_externally_triggered = bool(value)
[docs] def validate_every(self, frequency, for_num_iterations=None): """ Set validation frequency. Parameters ---------- frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int Validation frequency. If str, it could be (say) '10 iterations' or '1 epoch'. If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). If int (say 10), it's interpreted as (10, 'iterations'). for_num_iterations : int Number of iterations to validate for. If not set, the model is validated on the entire dataset (i.e. till the data loader is exhausted). Returns ------- Trainer self """ self._validate_every = tu.Frequency.build_from(frequency, priority='iterations') assert self._validate_every.is_consistent self._num_validation_iterations = for_num_iterations return self
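    # Usage sketch (illustrative, not from the original source): validate every 500
    # training iterations, but only over 100 batches of the validation loader rather
    # than the full dataset:
    #
    #     trainer.validate_every('500 iterations', for_num_iterations=100)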
@property def iteration_count(self): return self._iteration_count @property def epoch_count(self): return self._epoch_count @property def target_batch_dim(self): return self._target_batch_dim @target_batch_dim.setter def target_batch_dim(self, value): assert_(value in [0, 1], "target_batch_dim must be either 0 or 1, got {value} instead.".format(value=value), ValueError) self._target_batch_dim = value
    def set_target_batch_dim(self, value):
        self.target_batch_dim = value
        return self
[docs] def build_logger(self, logger=None, log_directory=None, **kwargs): """ Build the logger. Parameters ---------- logger : inferno.trainers.callbacks.logging.base.Logger or str or type Must either be a Logger object or the name of a logger or the class of a logger. log_directory : str Path to the directory where the log files are to be stored. kwargs : dict Keyword arguments to the logger class. Returns ------- Trainer self """ if isinstance(logger, Logger): # Set logger and register with the callback engine. self._logger = logger self.callbacks.register_callback(self._logger) elif callable(logger): kwargs.update({'log_directory': log_directory}) self._logger = logger(**kwargs) self.callbacks.register_callback(self._logger) elif isinstance(logger, str): self._logger = get_logger(logger)(**kwargs) self.callbacks.register_callback(self._logger) elif logger is None: pass else: raise NotImplementedError if log_directory is not None: self.set_log_directory(log_directory) return self
[docs] def set_log_directory(self, log_directory): """ Set the directory where the log files are to be stored. Parameters ---------- log_directory : str Directory where the log files are to be stored. Returns ------- Trainer self """ self._log_directory = log_directory if self._logger is not None: self._logger.set_log_directory(log_directory) return self
    # States that are fetched dynamically from the trainer object via properties are
    # dynamic states. Such states can not be updated.
    # The following dictionary maps state keys to the corresponding trainer attribute
    DYNAMIC_STATES = {'learning_rate': 'current_learning_rate'}

    def update_state(self, key, value):
        assert key not in self.DYNAMIC_STATES, \
            "State at key '{}' cannot be updated because it's dynamic.".format(key)
        self._state.update({key: value})
        return self

    def update_state_from_dictionary(self, dictionary):
        # Unwrap variables (or tensors)
        self._state.update({state_key: thu.unwrap(state)
                            for state_key, state in dictionary.items()})

    def update_state_from_model_state_hooks(self):
        if hasattr(self.model, '_state_hooks'):
            state_hooks = getattr(self.model, '_state_hooks')
            if isinstance(state_hooks, dict):
                self.update_state_from_dictionary(state_hooks)

    def get_state(self, key, default=None):
        if key in self.DYNAMIC_STATES:
            return getattr(self, self.DYNAMIC_STATES.get(key), default)
        else:
            return self._state.get(key, default)
@property def current_learning_rate(self): return self.get_current_learning_rate()
[docs] def get_current_learning_rate(self): """ Gets the current learning rate. Returns ------- list or float List of learning rates if there are multiple parameter groups, or a float if there's just one. """ learning_rate = [param_group.get('lr', -1.) for param_group in self.optimizer.param_groups] learning_rate = [_learning_rate[0] if thu.is_tensor(_learning_rate) else _learning_rate for _learning_rate in learning_rate] return pyu.from_iterable(learning_rate)
[docs] def cuda(self, devices=None, base_device=None): """ Train on the GPU. Parameters ---------- devices : list Specify the ordinals of the devices to use for dataparallel training. base_device : {'cpu', 'cuda'} When using data-parallel training, specify where the result tensors are collected. If 'cuda', the results are collected in `devices[0]`. Returns ------- Trainer self """ # Validate base_device assert_(base_device in [None, 'cpu', 'cuda'], "`base_device` must either be 'cpu' or 'cuda', got {} instead." .format(base_device), DeviceError) if isinstance(devices, int) or (isinstance(devices, (list, tuple)) and len(devices) == 1): # No data-parallelism, make sure base_device is not CPU assert_(base_device != 'cpu', "Without dataparallelism, `base_device` cannot be 'cpu'.", DeviceError) self._base_device_ordinal = {None: None, 'cpu': -1, 'cuda': None}.get(base_device) # Move model to CUDA if self.model_is_defined: self.model.cuda() # Move criterion to cuda if base device ordinal is not -1 (i.e. CPU) # (the criterion is evaluated on the base device) if self.criterion_is_defined and self._base_device_ordinal != -1: self.criterion.cuda() elif self.criterion_is_defined and self._base_device_ordinal == -1: # Criterion is evaluated on the CPU, make sure that's where it lives self.criterion.cpu() self._use_cuda = True self._devices = devices return self
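    # Usage sketch (illustrative, not from the original source) of the two documented
    # configurations: plain data-parallel training over two GPUs with results gathered
    # on devices[0], and data-parallel training with results (and the criterion)
    # collected on the CPU:
    #
    #     trainer.cuda(devices=[0, 1])
    #     trainer.cuda(devices=[0, 1], base_device='cpu')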
[docs] def cpu(self): """ Train on the CPU. Returns ------- Trainer self """ if self.model_is_defined: self.model.cpu() if self.criterion_is_defined: self.criterion.cpu() self._use_cuda = False self._devices = None return self
    def is_cuda(self):
        """Returns whether using GPU for training."""
        return self._use_cuda

    def to_device(self, objects):
        if isinstance(objects, (list, tuple)):
            return type(objects)([self.to_device(_object) for _object in objects])
        else:
            return objects.cuda() if self._use_cuda else objects
[docs] def apply_model(self, *inputs): if hasattr(self, '_base_device_ordinal'): # This is to not break old checkpoints base_device_ordinal = self._base_device_ordinal else: base_device_ordinal = None if self._devices is not None: return data_parallel(self.model, inputs, list(self._devices), output_device=base_device_ordinal) else: return self.model(*inputs)
[docs] def cast(self, objects): if isinstance(objects, (list, tuple)): return type(objects)([self.cast(_object) for _object in objects]) else: # Cast only the float types, while leaving the ints alone if objects.__class__.__name__ in ['HalfTensor', 'FloatTensor', 'DoubleTensor']: cast_fn = getattr(objects, self._dtype, None) else: cast_fn = None if cast_fn is not None: return cast_fn() else: return objects
    def set_precision(self, dtype):
        """
        Set training precision.

        Parameters
        ----------
        dtype : {'double', 'float', 'half'}
            Training precision.

        Returns
        -------
        Trainer
            self.
        """
        assert dtype in ['double', 'float', 'half']
        self._dtype = dtype
        self._model = getattr(self._model, dtype)()
        return self

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self.set_precision(value)
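    # Usage sketch (illustrative, not from the original source): switch the model to half
    # precision; subsequent batches are cast accordingly by `cast`. 'float' is the default,
    # 'double' and 'half' are the other accepted values:
    #
    #     trainer.set_precision('half')
    #     # or via the property:
    #     trainer.dtype = 'half'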
[docs] def bind_loader(self, name, loader, num_inputs=None, num_targets=1): """ Bind a data loader to the trainer. Parameters ---------- name : {'train', 'validate', 'test'} Name of the loader, i.e. what it should be used for. loader : torch.utils.data.DataLoader DataLoader object. num_inputs : int Number of input tensors from the `loader`. num_targets : int Number of target tensors from the `loader`. Returns ------- Trainer self Raises ------ KeyError if name is invalid. TypeError if loader is not a DataLoader instance. """ assert_(name in ['train', 'validate', 'test'], "`name` must be one of ['train', 'validate', 'test']. " "Got {} instead.".format(name), KeyError) assert_(isinstance(loader, DataLoader), "`loader` must be a DataLoader object. " "Got {} instead.".format(type(loader).__name__), TypeError) # Check to see if the loader is actually new. This should usually be True. is_new_loader = loader is not self._loaders.get(name) self._loaders.update({name: loader}) # We also need to account for the case when a loader is being replaced. When this happens, # the old DataLoaderIter might still have processes running, which we need to kill. if is_new_loader and name in self._loader_iters: # This is when the previous loader already has a DataLoaderIter running. # The DataLoaderIter implements a __del__ method, which shuts down workers. del self._loader_iters[name] # Trainers loaded from pickle files might not have '_loader_specs', therefore: if not hasattr(self, '_loader_specs'): setattr(self, '_loader_specs', {}) self._loader_specs.update({name: {'num_inputs': num_inputs, 'num_targets': num_targets}}) return self
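    # Usage sketch (illustrative, not from the original source): bind training and
    # validation loaders built from hypothetical `train_dataset` / `validate_dataset`
    # objects. num_inputs / num_targets tell split_batch how to separate the tensors
    # a loader yields:
    #
    #     from torch.utils.data import DataLoader
    #     trainer.bind_loader('train', DataLoader(train_dataset, batch_size=16, shuffle=True))
    #     trainer.bind_loader('validate', DataLoader(validate_dataset, batch_size=16),
    #                         num_inputs=1, num_targets=1)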
[docs] def get_loader_specs(self, name): assert name in self._loader_specs.keys(), \ "Could not find specs about loader '{}'. Valid loader names are: {}" \ .format(name, set(self._loader_specs.keys())) return self._loader_specs.get(name)
[docs] def fetch_next_batch(self, from_loader='train', restart_exhausted_generators=True, update_batch_count=True, update_epoch_count_if_generator_exhausted=True): # Check if the iterator is built if from_loader not in self._loader_iters: self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()}) # Try to fetch from iterator try: # Fetch next_batch = next(self._loader_iters[from_loader]) # Verify self.verify_batch(next_batch, from_loader) if update_batch_count: self._batch_count += 1 return next_batch except StopIteration: # This if clause prevents infinite recursion if the loader is empty if restart_exhausted_generators: self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()}) # Update epoch count if update_epoch_count_if_generator_exhausted: self.next_epoch() return self.fetch_next_batch(from_loader, restart_exhausted_generators=False, update_batch_count=update_batch_count) else: raise
[docs] def verify_batch(self, batch, from_loader): loader_specs = self.get_loader_specs(from_loader) num_inputs = loader_specs.get('num_inputs') num_targets = loader_specs.get('num_targets') if None not in [num_inputs, num_targets]: assert len(batch) == num_inputs + num_targets, \ "Was expecting a batch with {} (= num_inputs) + {} (= num_targets) tensors, " \ "got one with {} tensors.".format(num_inputs, num_targets, len(batch)) if num_inputs is not None: assert len(batch) > num_inputs, \ "Expecting {} inputs, but the batch contains only {} tensors." \ .format(num_inputs, len(batch)) if num_targets is not None: assert len(batch) > num_targets, \ "Expecting {} outputs, but the batch contains only {} tensors." \ .format(num_targets, len(batch)) return batch
[docs] def split_batch(self, batch, from_loader): loader_specs = self.get_loader_specs(from_loader) num_inputs = loader_specs.get('num_inputs') num_targets = loader_specs.get('num_targets') assert not (num_targets is None and num_inputs is None), \ "Can not split batch if both the number of inputs and targets is not known." if num_inputs is None: # Unknown number of inputs inputs, targets = batch[:-num_targets], batch[-num_targets:] elif num_targets is None: # Unknown number of targets inputs, targets = batch[:num_inputs], batch[num_inputs:] else: # Known number of inputs and targets inputs, targets = batch[:num_inputs], batch[-num_targets:] return inputs, pyu.from_iterable(targets)
[docs] def restart_generators(self, of_loader=None): if of_loader is None: of_loader = self._loaders.keys() else: assert of_loader in self._loaders.keys(), \ "Key {} not in loaders ({})".format(of_loader, list(self._loaders)) of_loader = pyu.to_iterable(of_loader) self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__() for from_loader in of_loader}) return self
[docs] def wrap_batch(self, batch, from_loader=None, requires_grad=False, volatile=False): base_device_ordinal = \ self._base_device_ordinal if hasattr(self, '_base_device_ordinal') else None # First, send to the right device if base_device_ordinal is None: # Both inputs and labels are sent to the device batch = self.to_device(batch) elif base_device_ordinal == -1: # Input batches go to device, while labels remain on the CPU. # To start, we need the number of input batches, i.e. from_loader must not be None assert_(from_loader is not None, "`from_loader` needs to be specified if base_device_ordinal is -1 " "(i.e. base device for data-parallel training is CPU).", ValueError) loader_spec = self._loader_specs.get(from_loader) assert_(loader_spec is not None, "No `loader_spec` found for loader key '{}'.".format(from_loader), RuntimeError) # Get number of targets num_targets = loader_spec['num_targets'] # Fetch input batches and send'em to device (leave the targets alone) inputs = batch[:-num_targets] inputs = self.to_device(inputs) # Finally, build the batch batch = inputs + batch[-num_targets:] else: raise ValueError("Internal Error: Invalid base_device_ordinal: {}." .format(base_device_ordinal)) # Cast to the right dtype batch = self.cast(batch) # Second, wrap as variable batch = type(batch)([Variable(_batch, requires_grad=requires_grad, volatile=volatile) for _batch in batch]) return batch
    def next_iteration(self):
        self._iteration_count += 1

    def next_epoch(self):
        # Callback before the end of epoch
        self.callbacks.call(self.callbacks.END_OF_EPOCH,
                            epoch_count=self._epoch_count,
                            batch_count=self._batch_count,
                            iteration_count=self._iteration_count)
        self._epoch_count += 1
        self._batch_count = 0
        # Callback after the start of epoch
        self.callbacks.call(self.callbacks.BEGIN_OF_EPOCH,
                            epoch_count=self._epoch_count,
                            batch_count=self._batch_count,
                            iteration_count=self._iteration_count)
[docs] def stop_fitting(self, max_num_iterations=None, max_num_epochs=None): # First priority to iteration count if max_num_iterations is not None or max_num_epochs is None: max_num_iterations = \ self._max_num_iterations if max_num_iterations is None else max_num_iterations assert_(max_num_iterations is not None, "Neither max_num_iterations nor max_num_epochs was set.", RuntimeError) return self._iteration_count >= max_num_iterations else: # max_num_epochs is specified. It could be 'auto', in which case we read from the # class attribute max_num_epochs = self._max_num_epochs \ if isinstance(max_num_epochs, str) and max_num_epochs.lower() == 'auto' \ else max_num_epochs return self._epoch_count >= max_num_epochs
INF_STRINGS = {'inf', 'infinity', 'infty'}
[docs] def set_max_num_iterations(self, max_num_iterations): """ Set the maximum number of training iterations. Parameters ---------- max_num_iterations : int or float or str Maximum number of training iterations. If float, it should equal numpy.inf. If str, it should be one of {'inf', 'infinity', 'infty'}. Returns ------- Trainer self """ max_num_iterations = \ inf if max_num_iterations in self.INF_STRINGS else max_num_iterations # Validate type assert_(isinstance(max_num_iterations, int) or max_num_iterations == inf, "max_num_iterations must be an integer or numpy.inf, got {} instead." .format(type(max_num_iterations).__name__), TypeError) self._max_num_iterations = max_num_iterations return self
[docs] def set_max_num_epochs(self, max_num_epochs): """ Set the maximum number of training epochs. Parameters ---------- max_num_epochs : int or float or str Maximum number of training epochs. If float, it should equal numpy.inf. If str, it should be one of {'inf', 'infinity', 'infty'}. Returns ------- Trainer self """ max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs assert_(isinstance(max_num_epochs, int) or max_num_epochs == inf, "max_num_epochs must be an integer or numpy.inf, got {} instead." .format(type(max_num_epochs).__name__), TypeError) self._max_num_epochs = max_num_epochs return self
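    # Usage sketch (illustrative, not from the original source): either bound may be a
    # number or 'inf' for an unbounded run. Note that stop_fitting gives priority to the
    # iteration bound when both are set:
    #
    #     trainer.set_max_num_epochs(50)          # stop after 50 epochs
    #     # or, alternatively, bound the run by iterations:
    #     trainer.set_max_num_iterations(10000)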
[docs] def fit(self, max_num_iterations=None, max_num_epochs=None): """ Fit model. Parameters ---------- max_num_iterations : int or float or str (Optional) Maximum number of training iterations. Overrides the value set by `Trainer.set_max_num_iterations`. If float, it should equal numpy.inf. If str, it should be one of {'inf', 'infinity', 'infty'}. max_num_epochs : int or float or str (Optional) Maximum number of training epochs. Overrides the value set by `Trainer.set_max_num_epochs`. If float, it should equal numpy.inf. If str, it should be one of {'inf', 'infinity', 'infty'}. Returns ------- Trainer self """ # Takes care of: # - dispatching train # - validation # - learning rate scheduling # - saving max_num_iterations = inf if max_num_iterations in self.INF_STRINGS else max_num_iterations max_num_iterations = self._max_num_iterations if max_num_iterations is None \ else max_num_iterations max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs max_num_epochs = self._max_num_epochs if max_num_epochs is None else max_num_epochs self.callbacks.call(self.callbacks.BEGIN_OF_FIT, max_num_iterations=max_num_iterations, max_num_epochs=max_num_epochs) # Local clock run_num = 0 while True: if self.stop_fitting(max_num_iterations, max_num_epochs): self.console.info("Exceeded max number of iterations / epochs, breaking.") break # Train self.train_for(break_callback=lambda *args: self.stop_fitting(max_num_iterations, max_num_epochs)) # Check if it's time to validate if self.validate_now: self.console.info("Validating.") self.validate_for() # Check if it's time to save if self.save_now: self.console.info("Saving.") self.save() run_num += 1 # Call callback self.callbacks.call(self.callbacks.END_OF_FIT, max_num_iterations=max_num_iterations, max_num_epochs=max_num_epochs, num_runs=run_num) return self
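    # Illustrative end-to-end sketch (not from the original source) pulling the pieces
    # above together; `model`, `train_loader` and `validate_loader` are assumed to be
    # defined elsewhere:
    #
    #     trainer = Trainer(model) \
    #         .build_criterion('CrossEntropyLoss') \
    #         .build_optimizer('Adam', lr=1e-3) \
    #         .validate_every('1 epoch') \
    #         .save_every('1 epoch', to_directory='./checkpoints') \
    #         .set_max_num_epochs(10)
    #     trainer.bind_loader('train', train_loader).bind_loader('validate', validate_loader)
    #     trainer.fit()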
[docs] def apply_model_and_loss(self, inputs, target, backward=True): # Compute prediction prediction = self.apply_model(*inputs) # Compute loss kwargs = {} if (isinstance(self.criterion, torch.nn.Module) and 'trainer' in signature(self.criterion.forward).parameters): kwargs['trainer'] = self loss = self.criterion(prediction, target, **kwargs) if backward: # Backprop if required loss.backward() return prediction, loss
[docs] def train_for(self, num_iterations=None, break_callback=None): # Switch model to train mode self.train_mode() # Call callback self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_RUN, num_iterations=num_iterations) # iteration_num is a local clock. There's the global self._iteration_count that keeps # actual track of the number of iterations - this is updated by the call to # self.next_iteration(). iteration_num = 0 while True: if num_iterations is not None and iteration_num > num_iterations: self.console.info("Finished {} iterations. Breaking...".format(num_iterations)) break # Break if break callback asks us to if break_callback is not None and break_callback(iteration_num): self.console.info("Breaking on request from callback.") break self.console.progress("Training iteration {} (batch {} of epoch {})." .format(iteration_num, self._batch_count, self._epoch_count)) # Call callback self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_ITERATION, iteration_num=iteration_num) # Zero out the grads self.optimizer.zero_grad() # No interrupts while computing - a SIGINT could shoot down the driver if # done at the wrong time. Not sure if this has something to do with pinned memory with pyu.delayed_keyboard_interrupt(): # Get batch batch = self.fetch_next_batch('train') # Send to device and wrap as variable batch = self.wrap_batch(batch, from_loader='train') # Separate inputs from targets inputs, target = self.split_batch(batch, from_loader='train') # Apply model, compute loss and backprop prediction, loss = self.apply_model_and_loss(inputs, target, backward=True) # Compute metric if self.metric_is_defined and self.evaluate_metric_now: self._last_metric_evaluated_at_epoch = self._epoch_count error = self.metric(thu.unwrap(prediction, to_cpu=False), thu.unwrap(target, to_cpu=False)) self.update_state('training_error', thu.unwrap(error)) else: error = None # Update state from computation self.update_state('training_inputs', thu.unwrap(inputs)) self.update_state('training_target', thu.unwrap(target)) self.update_state('training_prediction', thu.unwrap(prediction)) self.update_state('training_loss', thu.unwrap(loss)) # Update state from model's state hooks self.update_state_from_model_state_hooks() # Update parameters self.optimizer.step() # Call callback self.callbacks.call(self.callbacks.END_OF_TRAINING_ITERATION, iteration_num=iteration_num) # Prepare for next iteration self.next_iteration() # Break if validating or saving. It's important that the next_iteration() method is # called before checking validate_now and save_now - because otherwise, the iteration # counter is never updated after the first save and validate, resulting in an infinite # save + validate loop. if self.validate_now: self.console.info("Breaking to validate.") break if self.save_now: self.console.info("Breaking to save.") break iteration_num += 1 self.callbacks.call(self.callbacks.END_OF_TRAINING_RUN, num_iterations=num_iterations) return self
[docs] def validate_for(self, num_iterations=None, loader_name='validate'): """ Validate for a given number of validation (if `num_iterations is not None`) or over the entire (validation) data set. Parameters ---------- num_iterations : int Number of iterations to validate for. To validate on the entire dataset, leave this as `None`. loader_name : str Name of the data loader to use for validation. 'validate' is the obvious default. Returns ------- Trainer self. """ assert_(loader_name in ['validate', 'test', 'train'], "Invalid `loader_name`: {}".format(loader_name), ValueError) # Average over errors validation_error_meter = tu.AverageMeter() validation_loss_meter = tu.AverageMeter() iteration_num = 0 num_iterations = \ self._num_validation_iterations if num_iterations is None else num_iterations # Switch to eval mode (e.g. for batchnorm, etc.) self.eval_mode() if loader_name not in self._loader_iters: self._loader_iters.update({loader_name: self._loaders[loader_name].__iter__()}) # If we don't know num_iterations, we're validating the entire dataset - so we might as # well restart the loader now if num_iterations is None: self.restart_generators(loader_name) # Record the epoch we're validating in self._last_validated_at_epoch = self._epoch_count self._last_validated_at_iteration = self._iteration_count self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_RUN, num_iterations=num_iterations, num_iterations_in_generator=len(self._loader_iters[loader_name]), last_validated_at_epoch=self._last_validated_at_epoch) while True: if num_iterations is not None and iteration_num > num_iterations: break self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_ITERATION, iteration_num=iteration_num) try: batch = self.fetch_next_batch(loader_name, restart_exhausted_generators= num_iterations is not None, update_batch_count=False, update_epoch_count_if_generator_exhausted=False) except StopIteration: self.console.info("{} generator exhausted, breaking.".format(loader_name)) break self.console.progress("Validating iteration {}.".format(iteration_num)) # Delay SIGINTs till after computation with pyu.delayed_keyboard_interrupt(): # Wrap batch = self.wrap_batch(batch, from_loader=loader_name, volatile=True) # Separate inputs, target = self.split_batch(batch, from_loader=loader_name) # Apply model, compute loss output, loss = self.apply_model_and_loss(inputs, target, backward=False) if isinstance(target, (list,tuple)): batch_size = target[0].size(self._target_batch_dim) else: batch_size = target.size(self._target_batch_dim) validation_loss_meter.update(loss.data[0], n=batch_size) # Compute validation_error if self.metric_is_defined: validation_error = self.metric(thu.unwrap(output, to_cpu=False), thu.unwrap(target, to_cpu=False)) if torch.is_tensor(validation_error): # Convert to float validation_error = validation_error[0] self.update_state('validation_error', thu.unwrap(validation_error)) validation_error_meter.update(validation_error, n=batch_size) self.update_state('validation_inputs', thu.unwrap(inputs)) self.update_state('validation_target', thu.unwrap(target)) self.update_state('validation_prediction', thu.unwrap(output)) self.update_state('validation_loss', thu.unwrap(loss)) # This is here for legacy reasons and will eventually be deprecated. 
self.update_state('validation_input', self.get_state('validation_inputs')) # Update from model's state hooks self.update_state_from_model_state_hooks() self.callbacks.call(self.callbacks.END_OF_VALIDATION_ITERATION, iteration_num=iteration_num) iteration_num += 1 self.console.info("Done validating. Logging results...") # Report validation_results = { 'validation_loss': validation_loss_meter.avg, 'validation_error': (validation_error_meter.avg if self.metric_is_defined else None) } self.record_validation_results(**validation_results) self.console.info("Validation loss: {validation_loss}; validation error: {validation_error}".format(**validation_results)) self.callbacks.call(self.callbacks.END_OF_VALIDATION_RUN, validation_loss_meter=validation_loss_meter, validation_error_meter= validation_error_meter if self.metric_is_defined else None) return self
[docs] def record_validation_results(self, validation_loss, validation_error): # Update state self.update_state('validation_loss_averaged', thu.unwrap(validation_loss)) if validation_error is not None: self.update_state('validation_error_averaged', thu.unwrap(validation_error)) # Prefer the error metric (if provided). This should be handled with care - # validation error should either always not be None, or otherwise. validation_score = validation_loss if validation_error is None else validation_error # Check if validation error is less than the best so far if self._best_validation_score is None or validation_score < self._best_validation_score: # Best score so far. The following flag will trigger a save self._is_iteration_with_best_validation_score = True self._best_validation_score = validation_score
[docs] def get_config(self, exclude_loader=True): # Returns a config dictionary, like __getstate__. Except optionally without the # data loaders (which might be yuuuuuge if it contains the data) config_dict = dict(self.__dict__) # Loader iterators can't be pickled if '_loader_iters' in config_dict: config_dict.update({'_loader_iters': {}}) if exclude_loader: if '_loaders' in config_dict: config_dict.update({'_loaders': {}}) return config_dict
[docs] def set_config(self, config_dict): # TODO some sanity checks on config_dict (e.g. whether the model is actually a model, etc) self.__dict__.update(config_dict) # Rebind trainer to callback engine self.callbacks.bind_trainer(self) # Have callback engine rebind all callbacks to trainer self.callbacks.rebind_trainer_to_all_callbacks() return self
[docs] def save(self, exclude_loader=True, stash_best_checkpoint=True): # Log the epoch for save_now self._last_saved_at_epoch = self._epoch_count self.callbacks.call(self.callbacks.BEGIN_OF_SAVE, save_to_directory=self._save_to_directory, epoch_count=self._epoch_count, batch_count=self._batch_count, iteration_count=self._iteration_count, is_iteration_with_best_validation_score=self._is_iteration_with_best_validation_score) checkpoint_path = os.path.join(self._save_to_directory, self._checkpoint_filename) best_checkpoint_path = os.path.join(self._save_to_directory, self._best_checkpoint_filename) # Save the state dictionary torch.save(self.get_config(exclude_loader=exclude_loader), checkpoint_path, pickle_module=dill) self.callbacks.call(self.callbacks.END_OF_SAVE, save_to_directory=self._save_to_directory, checkpoint_path=checkpoint_path, best_checkpoint_path=best_checkpoint_path, epoch_count=self._epoch_count, batch_count=self._batch_count, iteration_count=self._iteration_count, is_iteration_with_best_validation_score=self._is_iteration_with_best_validation_score) if self._is_iteration_with_best_validation_score and stash_best_checkpoint: # Do the stashin' shutil.copyfile(checkpoint_path, best_checkpoint_path) # This is required to prevent an infinite save loop? self._is_iteration_with_best_validation_score = False self.console.info("Saved to {}.".format(self._save_to_directory)) return self
    def save_model(self, to_directory=None):
        to_directory = self._save_to_directory if to_directory is None else to_directory
        # Save the model object itself (not just its state dictionary)
        torch.save(self.model, os.path.join(to_directory, 'model.pytorch'),
                   pickle_module=dill)
        return self
[docs] def load(self, from_directory=None, best=False, filename=None): """ Load the trainer from checkpoint. Parameters ---------- from_directory : str Path to the directory where the checkpoint is located. The filename should be 'checkpoint.pytorch' if best=False, or 'best_checkpoint.pytorch' if best=True. best : bool Whether to load the best checkpoint. The filename in `from_directory` should be 'best_checkpoint.pytorch'. filename : str Overrides the default filename. Returns ------- Trainer self """ from_directory = self._save_to_directory if from_directory is None else from_directory assert from_directory is not None, "Nowhere to load from." # Get file name if filename is None: filename = self._best_checkpoint_filename if best else self._checkpoint_filename # Load the dictionary config_dict = torch.load(os.path.join(from_directory, filename), pickle_module=dill) # This is required to prevent an infinite save loop? self._is_iteration_with_best_validation_score = False # Set config self.set_config(config_dict) return self
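    # Usage sketch (illustrative, not from the original source): resume from the latest
    # checkpoint in './checkpoints', or from the best-scoring one:
    #
    #     trainer = Trainer().load(from_directory='./checkpoints')
    #     trainer = Trainer().load(from_directory='./checkpoints', best=True)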
[docs] def load_model(self, from_directory=None, filename=None): from_directory = self._save_to_directory if from_directory is None else from_directory filename = 'model.pytorch' if filename is None else filename # Load the model model = torch.load(os.path.join(from_directory, filename), pickle_module=dill) # Set model self.model = model return self
    def load_(self, *args, **kwargs):
        # Here for legacy reasons - use load instead.
        return self.load(*args, **kwargs)

    @pyu.deprecated("please use self.console.{info,progress,warning,debug} instead")
    def print(self, message):
        print("[+][{}] {}".format(str(datetime.now()), message))
[docs] @classmethod def build(cls, model=None, **trainer_config): """Factory function to build the trainer.""" # Check if trainer is to be loaded from file if trainer_config.get('load_from_checkpoint'): # Load checkpoint config trainer = cls(model).save_every(**trainer_config.get('checkpoint_config')) trainer.load_() else: trainer = cls(model) if 'logger_config' in trainer_config: trainer.build_logger(**trainer_config.get('logger_config')) if 'criterion_config' in trainer_config: trainer.build_criterion(**trainer_config.get('criterion_config')) if 'optimizer_config' in trainer_config: trainer.build_optimizer(**trainer_config.get('optimizer_config')) if 'metric_config' in trainer_config: trainer.build_metric(**trainer_config.get('metric_config')) if 'checkpoint_config' in trainer_config: trainer.save_every(**trainer_config.get('checkpoint_config')) if 'validation_config' in trainer_config: trainer.validate_every(**trainer_config.get('validation_config')) if 'max_num_iterations' in trainer_config: trainer.set_max_num_iterations(trainer_config.get('max_num_iterations')) if 'max_num_epochs' in trainer_config: trainer.set_max_num_epochs(trainer_config.get('max_num_epochs')) if trainer_config.get('use_cuda'): devices = trainer_config.get('use_cuda').get('devices') \ if isinstance(trainer_config.get('use_cuda'), dict) else None trainer.cuda(devices=devices) if 'training_precision' in trainer_config: trainer.set_precision(trainer_config.get('training_precision')) return trainer
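    # Usage sketch (illustrative, not from the original source): the same trainer
    # assembled from a nested configuration dictionary via the factory classmethod.
    # The keys shown are the ones this method queries:
    #
    #     trainer = Trainer.build(
    #         model,
    #         criterion_config={'method': 'CrossEntropyLoss'},
    #         optimizer_config={'method': 'Adam', 'lr': 1e-3},
    #         checkpoint_config={'frequency': '1 epoch', 'to_directory': './checkpoints'},
    #         validation_config={'frequency': '1 epoch'},
    #         max_num_epochs=10,
    #         use_cuda=True)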