Diff of /utils/losses.py [000000] .. [903821]

Switch to side-by-side view

--- a
+++ b/utils/losses.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def dice_loss(score, target):
+    target = target.float()
+    smooth = 1e-5
+    intersect = torch.sum(score * target)
+    y_sum = torch.sum(target * target)
+    z_sum = torch.sum(score * score)
+    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
+    loss = 1 - loss
+    return loss
+
+def Binary_dice_loss(predictive, target, ep=1e-8):
+    intersection = 2 * torch.sum(predictive * target) + ep
+    union = torch.sum(predictive) + torch.sum(target) + ep
+    loss = 1 - intersection / union
+    return loss
+
+def kl_loss(inputs, targets, ep=1e-8):
+    kl_loss=nn.KLDivLoss(reduction='mean')
+    consist_loss = kl_loss(torch.log(inputs+ep), targets)
+    return consist_loss
+
+def soft_ce_loss(inputs, target, ep=1e-8):
+    logprobs = torch.log(inputs+ep)
+    return  torch.mean(-(target[:,0,...]*logprobs[:,0,...]+target[:,1,...]*logprobs[:,1,...]))
+
+def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
+    """Takes softmax on both sides and returns KL divergence
+
+    Note:
+    - Returns the sum over all examples. Divide by the batch size afterwards
+      if you want the mean.
+    - Sends gradients to inputs but not the targets.
+    """
+    assert input_logits.size() == target_logits.size()
+    if sigmoid:
+        input_log_softmax = torch.log(torch.sigmoid(input_logits))
+        target_softmax = torch.sigmoid(target_logits)
+    else:
+        input_log_softmax = F.log_softmax(input_logits, dim=1)
+        target_softmax = F.softmax(target_logits, dim=1)
+
+    # return F.kl_div(input_log_softmax, target_softmax)
+    kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
+    # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...])
+    return kl_div
+
+def softmax_mse_loss(input_logits, target_logits):
+    """Takes softmax on both sides and returns MSE loss
+
+    Note:
+    - Returns the sum over all examples. Divide by the batch size afterwards
+      if you want the mean.
+    - Sends gradients to inputs but not the targets.
+    """
+    assert input_logits.size() == target_logits.size()
+    input_softmax = F.softmax(input_logits, dim=1)
+    target_softmax = F.softmax(target_logits, dim=1)
+
+    mse_loss = F.mse_loss(input_softmax,target_softmax)
+    return mse_loss
+
+def mse_loss(input1, input2):
+    return torch.mean((input1 - input2)**2)
+
+class DiceLoss(nn.Module):
+    def __init__(self, n_classes):
+        super(DiceLoss, self).__init__()
+        self.n_classes = n_classes
+
+    def _one_hot_encoder(self, input_tensor):
+        tensor_list = []
+        for i in range(self.n_classes):
+            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
+            tensor_list.append(temp_prob)
+        output_tensor = torch.cat(tensor_list, dim=1)
+        return output_tensor.float()
+
+    def _dice_loss(self, score, target):
+        target = target.float()
+        smooth = 1e-10
+        intersection = torch.sum(score * target)
+        union = torch.sum(score * score) + torch.sum(target * target) + smooth
+        loss = 1 - intersection / union
+        return loss
+
+    def forward(self, inputs, target, weight=None, softmax=False):
+        if softmax:
+            inputs = torch.softmax(inputs, dim=1)
+        target = self._one_hot_encoder(target)
+        if weight is None:
+            weight = [1] * self.n_classes
+        assert inputs.size() == target.size(), 'predict & target shape do not match'
+        class_wise_dice = []
+        loss = 0.0
+        for i in range(0, self.n_classes):
+            dice = self._dice_loss(inputs[:, i], target[:, i])
+            class_wise_dice.append(1.0 - dice.item())
+            loss += dice * weight[i]
+        return loss / self.n_classes