Source code for inferno.extensions.criteria.set_similarity_measures

import torch.nn as nn
from ...utils.torch_utils import flatten_samples
from torch.autograd import Variable


__all__ = ['SorensenDiceLoss']


[docs]class SorensenDiceLoss(nn.Module): """ Computes a loss scalar, which when minimized maximizes the Sorensen-Dice similarity between the input and the target. For both inputs and targets it must be the case that `input_or_target.size(1) = num_channels`. """ def __init__(self, weight=None, channelwise=True, eps=1e-6): """ Parameters ---------- weight : torch.FloatTensor or torch.cuda.FloatTensor Class weights. Applies only if `channelwise = True`. channelwise : bool Whether to apply the loss channelwise and sum the results (True) or to apply it on all channels jointly (False). """ super(SorensenDiceLoss, self).__init__() self.register_buffer('weight', weight) self.channelwise = channelwise self.eps = eps
[docs] def forward(self, input, target): if not self.channelwise: numerator = (input * target).sum() denominator = (input * input).sum() + (target * target).sum() loss = -2. * (numerator / denominator.clamp(min=self.eps)) else: # TODO This should be compatible with Pytorch 0.2, but check # Flatten input and target to have the shape (C, N), # where N is the number of samples input = flatten_samples(input) target = flatten_samples(target) # Compute numerator and denominator (by summing over samples and # leaving the channels intact) numerator = (input * target).sum(-1) denominator = (input * input).sum(-1) + (target * target).sum(-1) channelwise_loss = -2 * (numerator / denominator.clamp(min=self.eps)) if self.weight is not None: # With pytorch < 0.2, channelwise_loss.size = (C, 1). if channelwise_loss.dim() == 2: channelwise_loss = channelwise_loss.squeeze(1) # Wrap weights in a variable weight = Variable(self.weight, requires_grad=False) # Apply weight channelwise_loss = weight * channelwise_loss # Sum over the channels to compute the total loss loss = channelwise_loss.sum() return loss