# Source code for inferno.extensions.initializers.base

import torch.nn.init as init


# Public names exported by ``from ... import *``.
__all__ = [
    "Initializer",
    "Initialization",
    "WeightInitFunction",
    "BiasInitFunction",
    "TensorInitFunction",
]


class Initializer(object):
    """
    Base class for all initializers.

    Calling an initializer on a module whose class name is listed in
    ``VALID_LAYERS`` applies ``call_on_weight`` to ``module.weight.data`` and
    ``call_on_bias`` to ``module.bias.data`` (each only when that parameter
    exists and is not None). Both hooks default to ``call_on_tensor``, which
    subclasses override. A hook that raises ``NotImplementedError`` is
    skipped silently. The module is always returned.
    """
    # TODO Support LSTMs and GRUs
    VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d',
                    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
                    'Linear', 'Bilinear', 'Embedding'}

    def __call__(self, module):
        """
        Initialize ``module``'s weight and bias in place.

        Parameters
        ----------
        module : object
            The layer to initialize. Modules whose class name is not in
            ``VALID_LAYERS`` are left untouched.

        Returns
        -------
        object
            The same ``module`` object, so this works with ``module.apply``.
        """
        if module.__class__.__name__ in self.VALID_LAYERS:
            # Layers created without a bias (e.g. torch ``bias=False``) still
            # have a ``bias`` attribute, set to None, so test for None rather
            # than for attribute existence.
            weight = getattr(module, 'weight', None)
            if weight is not None:
                try:
                    self.call_on_weight(weight.data)
                except NotImplementedError:
                    # Don't cry if it's not implemented
                    pass
            bias = getattr(module, 'bias', None)
            if bias is not None:
                try:
                    self.call_on_bias(bias.data)
                except NotImplementedError:
                    pass
        return module

    def call_on_bias(self, tensor):
        """Initialize a bias tensor; defaults to ``call_on_tensor``."""
        return self.call_on_tensor(tensor)

    def call_on_weight(self, tensor):
        """Initialize a weight tensor; defaults to ``call_on_tensor``."""
        return self.call_on_tensor(tensor)

    def call_on_tensor(self, tensor):
        """Initialize any tensor. Subclasses override this (or the specific hooks)."""
        raise NotImplementedError

    @classmethod
    def _defines_any(cls, *method_names):
        """True if some class in ``cls``'s MRO other than ``Initializer`` defines one of ``method_names``."""
        # Walk the MRO so hooks inherited from a concrete parent initializer
        # count. Skip ``Initializer`` itself, whose hooks only raise
        # NotImplementedError.
        return any(name in vars(klass)
                   for klass in cls.__mro__ if klass is not Initializer
                   for name in method_names)

    @classmethod
    def initializes_weight(cls):
        """Whether this initializer class provides a weight hook (``call_on_tensor`` or ``call_on_weight``)."""
        return cls._defines_any('call_on_tensor', 'call_on_weight')

    @classmethod
    def initializes_bias(cls):
        """Whether this initializer class provides a bias hook (``call_on_tensor`` or ``call_on_bias``)."""
        return cls._defines_any('call_on_tensor', 'call_on_bias')
class Initialization(Initializer):
    """
    Initializer built from a separate weight initializer and bias initializer.

    Parameters
    ----------
    weight_initializer : Initializer or str or callable or None
        How to initialize weight tensors.

        - ``None``: weights are left untouched.
        - An ``Initializer`` instance: used as-is. It must report
          ``initializes_weight()``.
        - A string: names a function in ``torch.nn.init``.
        - Any other callable: wrapped in a ``WeightInitFunction``.
    bias_initializer : Initializer or str or callable or None
        Same as ``weight_initializer``, but for bias tensors. An instance
        must report ``initializes_bias()``, and strings/callables are wrapped
        in a ``BiasInitFunction``.

    Raises
    ------
    AssertionError
        If an initializer instance lacks the required hook, a string does not
        name a ``torch.nn.init`` function, or a spec is not callable.
    """
    def __init__(self, weight_initializer=None, bias_initializer=None):
        self.weight_initializer = self._resolve_initializer(
            weight_initializer, 'initializes_weight', WeightInitFunction)
        self.bias_initializer = self._resolve_initializer(
            bias_initializer, 'initializes_bias', BiasInitFunction)

    @staticmethod
    def _resolve_initializer(initializer, capability, wrapper_class):
        """Convert a user-supplied initializer spec into an ``Initializer`` instance."""
        if initializer is None:
            # The base Initializer's hooks raise NotImplementedError, which
            # Initializer.__call__ swallows, so the tensor is left as is.
            return Initializer()
        if isinstance(initializer, Initializer):
            # ``capability`` must be *called*. The bound method object itself
            # is always truthy, so without the call the check never runs.
            assert getattr(initializer, capability)(), \
                "{} does not support {}".format(type(initializer).__name__, capability)
            return initializer
        if isinstance(initializer, str):
            init_function = getattr(init, initializer, None)
            assert init_function is not None, \
                "torch.nn.init has no function named {!r}".format(initializer)
            return wrapper_class(init_function=init_function)
        # Provision for the initializer to be a plain function
        assert callable(initializer), \
            "Expected an Initializer, a str or a callable, got {!r}".format(initializer)
        return wrapper_class(init_function=initializer)

    def call_on_weight(self, tensor):
        """Delegate weight initialization to ``self.weight_initializer``."""
        return self.weight_initializer.call_on_weight(tensor)

    def call_on_bias(self, tensor):
        """Delegate bias initialization to ``self.bias_initializer``."""
        return self.bias_initializer.call_on_bias(tensor)
class WeightInitFunction(Initializer):
    """
    Initializer that applies a function to weight tensors only.

    Parameters
    ----------
    init_function : callable
        Invoked as ``init_function(tensor, *init_function_args,
        **init_function_kwargs)`` for every weight tensor.
    *init_function_args, **init_function_kwargs
        Extra arguments forwarded unchanged to ``init_function``.
    """
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(WeightInitFunction, self).__init__()
        assert callable(init_function)
        (self.init_function,
         self.init_function_args,
         self.init_function_kwargs) = (init_function,
                                       init_function_args,
                                       init_function_kwargs)

    def call_on_weight(self, tensor):
        """Run ``init_function`` on ``tensor`` and return whatever it returns."""
        fn = self.init_function
        extra_args, extra_kwargs = self.init_function_args, self.init_function_kwargs
        return fn(tensor, *extra_args, **extra_kwargs)
class BiasInitFunction(Initializer):
    """
    Initializer that applies a function to bias tensors only.

    Parameters
    ----------
    init_function : callable
        Invoked as ``init_function(tensor, *init_function_args,
        **init_function_kwargs)`` for every bias tensor.
    *init_function_args, **init_function_kwargs
        Extra arguments forwarded unchanged to ``init_function``.
    """
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(BiasInitFunction, self).__init__()
        assert callable(init_function)
        (self.init_function,
         self.init_function_args,
         self.init_function_kwargs) = (init_function,
                                       init_function_args,
                                       init_function_kwargs)

    def call_on_bias(self, tensor):
        """Run ``init_function`` on ``tensor`` and return whatever it returns."""
        fn = self.init_function
        extra_args, extra_kwargs = self.init_function_args, self.init_function_kwargs
        return fn(tensor, *extra_args, **extra_kwargs)
class TensorInitFunction(Initializer):
    """
    Initializer that applies one function to every tensor (weights and biases).

    Parameters
    ----------
    init_function : callable
        Invoked as ``init_function(tensor, *init_function_args,
        **init_function_kwargs)`` for each tensor.
    *init_function_args, **init_function_kwargs
        Extra arguments forwarded unchanged to ``init_function``.
    """
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(TensorInitFunction, self).__init__()
        assert callable(init_function)
        (self.init_function,
         self.init_function_args,
         self.init_function_kwargs) = (init_function,
                                       init_function_args,
                                       init_function_kwargs)

    def call_on_tensor(self, tensor):
        """Run ``init_function`` on ``tensor`` and return whatever it returns."""
        fn = self.init_function
        extra_args, extra_kwargs = self.init_function_args, self.init_function_kwargs
        return fn(tensor, *extra_args, **extra_kwargs)