--- a
+++ b/src/losses/dice_loss.py
@@ -0,0 +1,135 @@
+"""
+Code was adapted and modified from https://github.com/black0017/MedicalZooPytorch
+"""
+from typing import Tuple, Union
+
+import torch
+from torch import nn
+
+from src.losses import utils
+
+
+class DiceLoss(nn.Module):
+    """
+    Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
+    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
+
+    :param classes: class count passed to `utils.expand_as_one_hot` for the target in `forward`
+    :param weight: optional tensor with one weight per channel handed to `dice`, or None.
+        That is `classes` entries in the default mode and 2 entries when `eval_regions` is True,
+        because every region tensor built by `_reformat_labels` has exactly two channels.
+    :param sigmoid_normalization: apply `nn.Sigmoid` to the input if True, else `nn.Softmax(dim=1)`
+    :param eval_regions: if True, average the loss over the three region tensors built by
+        `_reformat_labels` instead of over the raw class channels
+    """
+
+    def __init__(self, classes=4, weight=None, sigmoid_normalization=True, eval_regions: bool = False):
+        super().__init__()
+
+        # Stored as a buffer (not a parameter); register_buffer accepts None.
+        self.register_buffer('weight', weight)
+        self.normalization = nn.Sigmoid() if sigmoid_normalization else nn.Softmax(dim=1)
+        self.classes = classes
+        self.eval_regions = eval_regions
+
+    def _flatten(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Flattens a given tensor such that the channel axis is first.
+        The shapes are transformed as follows:
+           (N, C, D, H, W) -> (C, N * D * H * W)
+        """
+        num_channels = tensor.size(1)
+        # Swap batch and channel axes: (N, C, D, H, W) -> (C, N, D, H, W), then collapse
+        # everything behind the channel axis: (C, N, D, H, W) -> (C, N * D * H * W).
+        return tensor.transpose(0, 1).reshape(num_channels, -1)
+
+    def _reformat_labels(self, seg_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Builds three two-channel tensors from a multi-channel mask.
+        Input format: (batch_size, channels, D, H, W)
+
+        Every result keeps channel 0 of `seg_mask` as its first channel. The second channel is
+        the sum of channels 1, 2 and 3 (`wt`), the sum of channels 1 and 3 (`tc`), or channel 3
+        alone (`et`).
+        NOTE(review): presumably the BraTS whole tumor / tumor core / enhancing tumor regions;
+        confirm the label-to-channel mapping against the data loader.
+
+        :param seg_mask: tensor with at least four channels on axis 1
+        :return: tuple `(wt, tc, et)`, each of shape (batch_size, 2, D, H, W)
+        """
+        first_channel = seg_mask[:, 0, ...]
+        wt = torch.stack([first_channel, torch.sum(seg_mask[:, [1, 2, 3], ...], dim=1)], dim=1)
+        tc = torch.stack([first_channel, torch.sum(seg_mask[:, [1, 3], ...], dim=1)], dim=1)
+        et = torch.stack([first_channel, seg_mask[:, 3, ...]], dim=1)
+        return wt, tc, et
+
+    def dice(self, input: torch.Tensor, target: torch.Tensor, weight: Union[torch.Tensor, float, None], epsilon=1e-6) -> torch.Tensor:
+        """
+        Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target.
+        Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
+
+        :param input: NxCxSpatial input tensor
+        :param target: NxCxSpatial target tensor
+        :param weight: per-channel weight with C entries (shape (C,) or (C, 1)), a scalar, or None
+        :param epsilon: lower bound of the denominator, prevents division by zero
+        :return: tensor with the (weighted) Dice coefficient of every channel, shape (C,)
+        """
+        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+        input = self._flatten(input)
+        target = self._flatten(target).float()
+
+        # Compute per channel Dice Coefficient
+        intersect = (input * target).sum(-1)
+        if weight is not None:
+            if torch.is_tensor(weight):
+                # A (C, 1) weight would broadcast against the (C,) intersection into a (C, C)
+                # matrix; flatten it so that every channel is scaled by exactly one weight.
+                weight = weight.reshape(-1)
+            intersect = weight * intersect
+
+        denominator = (input * input).sum(-1) + (target * target).sum(-1)
+        return 2 * (intersect / denominator.clamp(min=epsilon))
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Union[list, torch.Tensor]]:
+        """
+        Normalizes `input`, expands `target` with `utils.expand_as_one_hot` and computes the Dice loss.
+
+        :param input: un-normalized predictions, 5-D (N, C, D, H, W)
+        :param target: label tensor; it is passed through `utils.expand_as_one_hot` with
+            `self.classes` classes and must be 5-D afterwards
+        :return: `(loss, score, details)` with `loss = 1 - score`. If `eval_regions` is set,
+            `details` is the list `[wt_loss, tc_loss, et_loss]`; otherwise it is the tensor of
+            per-channel Dice coefficients with channel 0 dropped.
+        """
+        target = utils.expand_as_one_hot(target.long(), self.classes)
+
+        assert input.dim() == target.dim() == 5, (
+            f"'input' ({input.dim()}-D) and 'target' ({target.dim()}-D) must both be 5-D tensors"
+        )
+
+        input = self.normalization(input.float())
+
+        if self.eval_regions:
+            input_wt, input_tc, input_et = self._reformat_labels(input)
+            target_wt, target_tc, target_et = self._reformat_labels(target)
+
+            wt_dice = torch.mean(self.dice(input_wt, target_wt, weight=self.weight))
+            tc_dice = torch.mean(self.dice(input_tc, target_tc, weight=self.weight))
+            et_dice = torch.mean(self.dice(input_et, target_et, weight=self.weight))
+
+            wt_loss = 1 - wt_dice
+            tc_loss = 1 - tc_dice
+            et_loss = 1 - et_dice
+
+            # Loss and score are the plain averages over the three regions.
+            loss = 1/3 * (wt_loss + tc_loss + et_loss)
+            score = 1/3 * (wt_dice + tc_dice + et_dice)
+
+            return loss, score, [wt_loss, tc_loss, et_loss]
+
+        per_channel_dice = self.dice(input, target, weight=self.weight)  # per channel Dice coefficient
+
+        # average Dice score across all channels/classes
+        mean = torch.mean(per_channel_dice)
+        return 1. - mean, mean, per_channel_dice[1:]
\ No newline at end of file