Source code for inferno.extensions.initializers.presets

import numpy as np
import torch.nn.init as init
from torch.autograd import Variable
from functools import partial

from .base import Initialization, Initializer


__all__ = ['Constant', 'NormalWeights',
           'SELUWeightsZeroBias',
           'ELUWeightsZeroBias',
           'OrthogonalWeightsZeroBias',
           'KaimingNormalWeightsZeroBias']


class Constant(Initializer):
    """Initialize with a constant."""
    def __init__(self, constant):
        self.constant = constant

    def call_on_tensor(self, tensor):
        if isinstance(tensor, Variable):
            tensor = tensor.data
        tensor.fill_(self.constant)
        return tensor

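# Usage sketch (not part of the module above): `call_on_tensor` fills a tensor
# in place and returns it. Assumes inferno is importable under the module path
# in this page's title.
import torch
from inferno.extensions.initializers.presets import Constant

bias = torch.Tensor(10)            # uninitialized 1-D tensor
Constant(0.).call_on_tensor(bias)  # every entry of `bias` is now 0.
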
class NormalWeights(Initializer):
    """
    Initialize weights with random numbers drawn from a normal distribution
    with mean `mean` and standard deviation `stddev`.
    """
    def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):
        self.mean = mean
        self.stddev = stddev
        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in

    def compute_fan_in(self, tensor):
        # Dense weights are (out, in); conv weights are (out, in, *kernel_shape)
        if tensor.dim() == 2:
            return tensor.size(1)
        else:
            return np.prod(list(tensor.size())[1:])

    def call_on_weight(self, tensor):
        if isinstance(tensor, Variable):
            self.call_on_weight(tensor.data)
            return tensor
        # Compute stddev if required
        if self.sqrt_gain_over_fan_in is not None:
            stddev = self.stddev * \
                np.sqrt(self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor))
        else:
            stddev = self.stddev
        # Init
        tensor.normal_(self.mean, stddev)
        return tensor

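# Worked example (not part of the module above): for a convolutional weight of
# shape (out_channels, in_channels, kH, kW), `compute_fan_in` returns
# in_channels * kH * kW, so with `sqrt_gain_over_fan_in=g` the effective
# sampling stddev is stddev * sqrt(g / fan_in).
import torch
from inferno.extensions.initializers.presets import NormalWeights

weight = torch.Tensor(64, 3, 3, 3)  # fan-in = 3 * 3 * 3 = 27
NormalWeights(sqrt_gain_over_fan_in=1.).call_on_weight(weight)
# `weight` is now drawn from N(0, sqrt(1 / 27)), i.e. stddev ~ 0.192
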
class OrthogonalWeightsZeroBias(Initialization):
    def __init__(self, orthogonal_gain=1.):
        super(OrthogonalWeightsZeroBias, self)\
            .__init__(weight_initializer=partial(init.orthogonal, gain=orthogonal_gain),
                      bias_initializer=Constant(0.))

class KaimingNormalWeightsZeroBias(Initialization):
    def __init__(self, relu_leakage=0):
        super(KaimingNormalWeightsZeroBias, self)\
            .__init__(weight_initializer=partial(init.kaiming_normal, a=relu_leakage),
                      bias_initializer=Constant(0.))

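# Usage sketch (not part of the module above): each preset couples a weight
# initializer with zero-filled biases. Assuming the `Initialization` base class
# is callable on modules (as inferno's initializers are designed to be), a
# preset can be pushed through a whole network with `nn.Module.apply`.
import torch.nn as nn
from inferno.extensions.initializers.presets import KaimingNormalWeightsZeroBias

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Conv2d(64, 64, 3))
model.apply(KaimingNormalWeightsZeroBias())  # Kaiming-normal weights, zero biases
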
class SELUWeightsZeroBias(Initialization):
    def __init__(self):
        super(SELUWeightsZeroBias, self)\
            .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.),
                      bias_initializer=Constant(0.))

class ELUWeightsZeroBias(Initialization):
    def __init__(self):
        super(ELUWeightsZeroBias, self)\
            .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277),
                      bias_initializer=Constant(0.))

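# Note (not part of the module above): both presets reduce to NormalWeights
# with stddev = sqrt(gain / fan_in). The SELU preset uses gain 1, the
# initialization recommended for self-normalizing networks; the ELU preset
# hard-codes a gain of roughly 1.5505. A quick check of the resulting stddevs
# for a layer with fan-in 256:
import numpy as np

print(np.sqrt(1. / 256))                   # SELU preset: 0.0625
print(np.sqrt(1.5505188080679277 / 256))   # ELU preset:  ~0.0778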