Source code for inferno.trainers.callbacks.base

from ...utils import python_utils as pyu

class CallbackEngine(object):
    """
    Gathers and manages callbacks.

    Callbacks are callables which are to be called by trainers when certain events
    ('triggers') occur. They could be any callable object, but if endowed with a
    `bind_trainer` method, it's called when the callback is registered. It is
    recommended that callbacks (or their `__call__` methods) use the double-star
    syntax for keyword arguments: `call` invokes every callback registered for a
    trigger with the keyword arguments it was given plus `trigger=<trigger name>`.
    """
    # Triggers
    BEGIN_OF_FIT = 'begin_of_fit'
    END_OF_FIT = 'end_of_fit'
    BEGIN_OF_TRAINING_RUN = 'begin_of_training_run'
    END_OF_TRAINING_RUN = 'end_of_training_run'
    BEGIN_OF_EPOCH = 'begin_of_epoch'
    END_OF_EPOCH = 'end_of_epoch'
    BEGIN_OF_TRAINING_ITERATION = 'begin_of_training_iteration'
    END_OF_TRAINING_ITERATION = 'end_of_training_iteration'
    BEGIN_OF_VALIDATION_RUN = 'begin_of_validation_run'
    END_OF_VALIDATION_RUN = 'end_of_validation_run'
    BEGIN_OF_VALIDATION_ITERATION = 'begin_of_validation_iteration'
    END_OF_VALIDATION_ITERATION = 'end_of_validation_iteration'
    BEGIN_OF_SAVE = 'begin_of_save'
    END_OF_SAVE = 'end_of_save'

    # Default triggers known to every engine. This class-level set is treated as
    # read-only: `register_new_trigger` gives an engine its own copy instead of
    # mutating this shared set.
    TRIGGERS = {BEGIN_OF_FIT, END_OF_FIT,
                BEGIN_OF_TRAINING_RUN, END_OF_TRAINING_RUN,
                BEGIN_OF_EPOCH, END_OF_EPOCH,
                BEGIN_OF_TRAINING_ITERATION, END_OF_TRAINING_ITERATION,
                BEGIN_OF_VALIDATION_RUN, END_OF_VALIDATION_RUN,
                BEGIN_OF_VALIDATION_ITERATION, END_OF_VALIDATION_ITERATION,
                BEGIN_OF_SAVE, END_OF_SAVE}

    def __init__(self):
        self._trainer = None
        # Maps every known trigger name to the set of callbacks registered for it.
        self._callback_registry = {trigger: set() for trigger in self.TRIGGERS}
        self._last_known_epoch = None
        self._last_known_iteration = None

    def register_new_trigger(self, trigger_name):
        """
        Make a custom trigger known to this engine.

        Parameters
        ----------
        trigger_name : str
            Name of the trigger. Registering a trigger that is already known keeps
            the callbacks registered for it.
        """
        if 'TRIGGERS' not in self.__dict__:
            # Copy the class-level default set so the new trigger stays local to
            # this engine instead of leaking into every other engine (whose
            # registries would not have an entry for it).
            self.TRIGGERS = set(self.TRIGGERS)
        self.TRIGGERS.add(trigger_name)
        # setdefault rather than overwrite: do not drop existing callbacks.
        self._callback_registry.setdefault(trigger_name, set())

    def bind_trainer(self, trainer):
        """Bind `trainer` to this engine. Returns self."""
        self._trainer = trainer
        return self

    def unbind_trainer(self):
        """Forget the bound trainer. Returns self."""
        self._trainer = None
        return self

    @property
    def trainer_is_bound(self):
        """True if a trainer is currently bound to this engine."""
        return self._trainer is not None

    def register_callback(self, callback, trigger='auto', bind_trainer=True):
        """
        Register `callback` to be called when `trigger` occurs.

        Parameters
        ----------
        callback : callable
            The callback.
        trigger : str
            One of this engine's `TRIGGERS`, or 'auto' to register the callback for
            every known trigger for which it has a callable attribute of the same
            name.
        bind_trainer : bool
            If True, a trainer is bound to this engine and the callback has a
            callable `bind_trainer` attribute, it is called with the trainer.

        Returns
        -------
        CallbackEngine
            self, to allow chaining.

        Raises
        ------
        AssertionError
            If `callback` is not callable, `trigger` is unknown, or 'auto'
            registration found no matching trigger.
        """
        assert callable(callback)
        # Automatic callback registration based on their methods
        if trigger == 'auto':
            automatic_registration_successful = False
            for candidate in self.TRIGGERS:
                if pyu.has_callable_attr(callback, candidate):
                    automatic_registration_successful = True
                    self.register_callback(callback, candidate, bind_trainer)
            assert automatic_registration_successful, \
                "Callback could not be auto-registered: no triggers recognized."
            return self
        # Validate triggers
        assert trigger in self.TRIGGERS
        # Add to callback registry
        self._callback_registry.get(trigger).add(callback)
        # Register trainer with the callback if required
        bind_trainer_to_callback = self.trainer_is_bound and \
            bind_trainer and \
            pyu.has_callable_attr(callback, 'bind_trainer')
        if bind_trainer_to_callback:
            callback.bind_trainer(self._trainer)
        return self

    def rebind_trainer_to_all_callbacks(self):
        """Call `bind_trainer(trainer)` on every registered callback that has it."""
        # FIXME This makes bind_trainer in register_callback redundant,
        # especially if used by the trainer class, so... deprecate bind_trainer.
        if not self.trainer_is_bound:
            return
        for callbacks_at_trigger in self._callback_registry.values():
            for callback in callbacks_at_trigger:
                # A callback registered for several triggers is bound once per
                # trigger; binding is idempotent for the base `Callback` class.
                if pyu.has_callable_attr(callback, 'bind_trainer'):
                    callback.bind_trainer(self._trainer)

    def call(self, trigger, **kwargs):
        """
        Call every callback registered for `trigger`.

        Each callback receives `kwargs` plus `trigger=trigger` as keyword arguments.
        """
        assert trigger in self.TRIGGERS
        kwargs.update({'trigger': trigger})
        # Iterate over a snapshot so that callbacks may register further callbacks
        # for this trigger without "Set changed size during iteration".
        for callback in list(self._callback_registry.get(trigger)):
            callback(**kwargs)

    def get_config(self):
        """Return a shallow copy of the instance state with the trainer removed."""
        # Pop trainer: it must not be pickled along with the engine.
        config_dict = dict(self.__dict__)
        config_dict.update({'_trainer': None})
        return config_dict

    def set_config(self, config_dict):
        """Update the instance state from `config_dict`. Returns self."""
        self.__dict__.update(config_dict)
        return self

    def __getstate__(self):
        return self.get_config()

    def __setstate__(self, state):
        self.set_config(state)
class Callback(object):
    """
    Recommended (but not required) base class for callbacks.

    Calling an instance with ``trigger=<name>`` (as `CallbackEngine.call` does)
    dispatches to the method of the same name, if the instance has a callable
    attribute by that name. Every instance is also recorded in a per-class
    registry, retrievable with `get_instances`.
    """
    def __init__(self):
        self._trainer = None
        self._debugging = False
        self.register_instance(self)

    @classmethod
    def register_instance(cls, instance):
        """
        Record `instance` in the registry of `cls` (once).

        The registry is a list owned by `cls` itself, so instances of a subclass
        are not mixed into a parent class's registry. It holds strong references,
        so registered instances stay alive as long as the class does.
        """
        # Look only at this class's own namespace: `hasattr` would also find a
        # parent class's registry through the MRO and append to it.
        registry = cls.__dict__.get('_instance_registry')
        if registry is None:
            cls._instance_registry = [instance]
        elif instance not in registry:
            registry.append(instance)
        # An already-registered instance leaves the registry untouched.

    @classmethod
    def get_instances(cls):
        """
        Return the instances registered for `cls`, or None if there are none.

        The registry list is passed through `pyu.from_iterable` before being
        returned.
        """
        registry = cls.__dict__.get('_instance_registry')
        if registry is None:
            return None
        return pyu.from_iterable(registry)

    @property
    def trainer(self):
        """The trainer bound to this callback, or None."""
        return self._trainer

    def bind_trainer(self, trainer):
        """Bind `trainer` to this callback. Returns self."""
        self._trainer = trainer
        return self

    def unbind_trainer(self):
        """Forget the bound trainer. Returns self."""
        self._trainer = None
        return self

    def __call__(self, **kwargs):
        # Dispatch to the method named after the trigger; silently ignore triggers
        # this callback has no callable attribute for.
        if 'trigger' in kwargs:
            handler = getattr(self, kwargs.get('trigger'), None)
            if callable(handler):
                handler(**kwargs)

    def get_config(self):
        """Return a shallow copy of the instance state with the trainer removed."""
        config_dict = dict(self.__dict__)
        config_dict.update({'_trainer': None})
        return config_dict

    def set_config(self, config_dict):
        """Update the instance state from `config_dict`. Returns self."""
        self.__dict__.update(config_dict)
        return self

    def __getstate__(self):
        return self.get_config()

    def __setstate__(self, state):
        self.set_config(state)

    def toggle_debug(self):
        """Flip debug printing on/off. Returns self."""
        # getattr: tolerate subclasses that did not call Callback.__init__.
        self._debugging = not getattr(self, '_debugging', False)
        return self

    def debug_print(self, message):
        """
        Print `message` tagged with the class name when debugging is enabled.

        Uses the bound trainer's `print` method, or the builtin `print` when no
        trainer is bound.
        """
        if not getattr(self, '_debugging', False):
            return
        formatted = "[DEBUG::{}] {}".format(type(self).__name__, message)
        if self.trainer is None:
            print(formatted)
        else:
            self.trainer.print(formatted)