--- a
+++ b/losses.py
@@ -0,0 +1,309 @@
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+from torch.autograd import Variable
+from torch.nn import MSELoss, SmoothL1Loss, L1Loss
+
+def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
+    """
+    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given  a multi channel input and target.
+    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
+
+    Args:
+         input (torch.Tensor): NxCxSpatial input tensor
+         target (torch.Tensor): NxCxSpatial target tensor
+         epsilon (float): prevents division by zero
+         weight (torch.Tensor): Cx1 tensor of weight per channel/class
+    """
+
+    # input and target shapes must match
+    assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+    input = flatten(input)
+    target = flatten(target)
+    target = target.float()
+
+    # compute per channel Dice Coefficient
+    intersect = (input * target).sum(-1)
+    if weight is not None:
+        intersect = weight * intersect
+
+    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
+    denominator = (input * input).sum(-1) + (target * target).sum(-1)
+    return 2 * (intersect / denominator.clamp(min=epsilon))
+
+
+class _MaskingLossWrapper(nn.Module):
+    """
+    Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`.
+    """
+
+    def __init__(self, loss, ignore_index):
+        super(_MaskingLossWrapper, self).__init__()
+        assert ignore_index is not None, 'ignore_index cannot be None'
+        self.loss = loss
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        mask = target.clone().ne_(self.ignore_index)
+        mask.requires_grad = False
+
+        # mask out input/target so that the gradient is zero where on the mask
+        input = input * mask
+        target = target * mask
+
+        # forward masked input and target to the loss
+        return self.loss(input, target)
+
+    
+
+class SkipLastTargetChannelWrapper(nn.Module):
+    """
+    Loss wrapper which removes additional target channel
+    """
+
+    def __init__(self, loss, squeeze_channel=False):
+        super(SkipLastTargetChannelWrapper, self).__init__()
+        self.loss = loss
+        self.squeeze_channel = squeeze_channel
+
+    def forward(self, input, target):
+        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
+
+        # skips last target channel if needed
+        target = target[:, :-1, ...]
+
+        if self.squeeze_channel:
+            # squeeze channel dimension if singleton
+            target = torch.squeeze(target, dim=1)
+        return self.loss(input, target)
+
+
+class _AbstractDiceLoss(nn.Module):
+    """
+    Base class for different implementations of Dice loss.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super(_AbstractDiceLoss, self).__init__()
+        self.register_buffer('weight', weight)
+        # The output from the network during training is assumed to be un-normalized probabilities and we would
+        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
+        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
+        # However if one would like to apply Softmax in order to get the proper probability distribution from the
+        # output, just specify `normalization=Softmax`
+        assert normalization in ['sigmoid', 'softmax', 'none']
+        if normalization == 'sigmoid':
+            self.normalization = nn.Sigmoid()
+        elif normalization == 'softmax':
+            self.normalization = nn.Softmax(dim=1)
+        else:
+            self.normalization = lambda x: x
+
+    def dice(self, input, target, weight):
+        # actual Dice score computation; to be implemented by the subclass
+        raise NotImplementedError
+
+    def forward(self, input, target):
+        # get probabilities from logits
+        input = self.normalization(input)
+
+        # compute per channel Dice coefficient
+        per_channel_dice = self.dice(input, target, weight=self.weight)
+
+        # average Dice score across all channels/classes
+        return 1. - torch.mean(per_channel_dice)
+
+
+class DiceLoss(_AbstractDiceLoss):
+    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
+    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
+    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super().__init__(weight, normalization)
+
+    def dice(self, input, target, weight):
+        return compute_per_channel_dice(input, target, weight=self.weight)
+
+
+class GeneralizedDiceLoss(_AbstractDiceLoss):
+    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
+    """
+
+    def __init__(self, normalization='sigmoid', epsilon=1e-6):
+        super().__init__(weight=None, normalization=normalization)
+        self.epsilon = epsilon
+
+    def dice(self, input, target, weight):
+        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+        input = flatten(input)
+        target = flatten(target)
+        target = target.float()
+
+        if input.size(0) == 1:
+            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
+            # put foreground and background voxels in separate channels
+            input = torch.cat((input, 1 - input), dim=0)
+            target = torch.cat((target, 1 - target), dim=0)
+
+        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
+        w_l = target.sum(-1)
+        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
+        w_l.requires_grad = False
+
+        intersect = (input * target).sum(-1)
+        intersect = intersect * w_l
+
+        denominator = (input + target).sum(-1)
+        denominator = (denominator * w_l).clamp(min=self.epsilon)
+
+        return 2 * (intersect.sum() / denominator.sum())
+
+
+class BCEDiceLoss(nn.Module):
+    """Linear combination of BCE and Dice losses"""
+
+    def __init__(self, alpha, beta):
+        super(BCEDiceLoss, self).__init__()
+        self.alpha = alpha
+        self.bce = nn.BCEWithLogitsLoss()
+        self.beta = beta
+        self.dice = DiceLoss()
+
+    def forward(self, input, target):
+        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
+
+
+class WeightedCrossEntropyLoss(nn.Module):
+    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
+    """
+
+    def __init__(self, ignore_index=-1):
+        super(WeightedCrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        weight = self._class_weights(input)
+        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)
+
+    @staticmethod
+    def _class_weights(input):
+        # normalize the input first
+        input = F.softmax(input, dim=1)
+        flattened = flatten(input)
+        nominator = (1. - flattened).sum(-1)
+        denominator = flattened.sum(-1)
+        class_weights = Variable(nominator / denominator, requires_grad=False)
+        return class_weights
+
+
+class WeightedSmoothL1Loss(nn.SmoothL1Loss):
+    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
+        super().__init__(reduction="none")
+        self.threshold = threshold
+        self.apply_below_threshold = apply_below_threshold
+        self.weight = initial_weight
+
+    def forward(self, input, target):
+        l1 = super().forward(input, target)
+
+        if self.apply_below_threshold:
+            mask = target < self.threshold
+        else:
+            mask = target >= self.threshold
+
+        l1[mask] = l1[mask] * self.weight
+
+        return l1.mean()
+
+
+def flatten(tensor):
+    """Flattens a given tensor such that the channel axis is first.
+    The shapes are transformed as follows:
+       (N, C, D, H, W) -> (C, N * D * H * W)
+    """
+    # number of channels
+    C = tensor.size(1)
+    # new axis order
+    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
+    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
+    transposed = tensor.permute(axis_order)
+    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
+    return transposed.contiguous().view(C, -1)
+
+
+def get_loss_criterion(config):
+    """
+    Returns the loss function based on provided configuration
+    :param config: (dict) a top level configuration object containing the 'loss' key
+    :return: an instance of the loss function
+    """
+    assert 'loss' in config, 'Could not find loss function configuration'
+    loss_config = config['loss']
+    name = loss_config.pop('name')
+
+    ignore_index = loss_config.pop('ignore_index', None)
+    skip_last_target = loss_config.pop('skip_last_target', False)
+    weight = loss_config.pop('weight', None)
+
+    if weight is not None:
+        # convert to cuda tensor if necessary
+        weight = torch.tensor(weight).to(config['device'])
+
+    pos_weight = loss_config.pop('pos_weight', None)
+    if pos_weight is not None:
+        # convert to cuda tensor if necessary
+        pos_weight = torch.tensor(pos_weight).to(config['device'])
+
+    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
+
+    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
+        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
+        loss = _MaskingLossWrapper(loss, ignore_index)
+
+    if skip_last_target:
+        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
+
+    return loss
+
+
+#######################################################################################################################
+
+def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
+    if name == 'BCEWithLogitsLoss':
+        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+    elif name == 'BCEDiceLoss':
+        alpha = loss_config.get('alphs', 1.)
+        beta = loss_config.get('beta', 1.)
+        return BCEDiceLoss(alpha, beta)
+    elif name == 'CrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
+    elif name == 'WeightedCrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
+    elif name == 'PixelWiseCrossEntropyLoss':
+        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
+    elif name == 'GeneralizedDiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return GeneralizedDiceLoss(normalization=normalization)
+    elif name == 'DiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return DiceLoss(weight=weight, normalization=normalization)
+    elif name == 'MSELoss':
+        return MSELoss()
+    elif name == 'SmoothL1Loss':
+        return SmoothL1Loss()
+    elif name == 'L1Loss':
+        return L1Loss()
+    elif name == 'WeightedSmoothL1Loss':
+        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
+                                    initial_weight=loss_config['initial_weight'],
+                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
+    else:
+        raise RuntimeError(f"Unsupported loss function: '{name}'")