Diff of /utils/losses.py [000000] .. [903821]

Switch to unified view

a b/utils/losses.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
def dice_loss(score, target):
6
    target = target.float()
7
    smooth = 1e-5
8
    intersect = torch.sum(score * target)
9
    y_sum = torch.sum(target * target)
10
    z_sum = torch.sum(score * score)
11
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
12
    loss = 1 - loss
13
    return loss
14
15
def Binary_dice_loss(predictive, target, ep=1e-8):
16
    intersection = 2 * torch.sum(predictive * target) + ep
17
    union = torch.sum(predictive) + torch.sum(target) + ep
18
    loss = 1 - intersection / union
19
    return loss
20
21
def kl_loss(inputs, targets, ep=1e-8):
22
    kl_loss=nn.KLDivLoss(reduction='mean')
23
    consist_loss = kl_loss(torch.log(inputs+ep), targets)
24
    return consist_loss
25
26
def soft_ce_loss(inputs, target, ep=1e-8):
27
    logprobs = torch.log(inputs+ep)
28
    return  torch.mean(-(target[:,0,...]*logprobs[:,0,...]+target[:,1,...]*logprobs[:,1,...]))
29
30
def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
31
    """Takes softmax on both sides and returns KL divergence
32
33
    Note:
34
    - Returns the sum over all examples. Divide by the batch size afterwards
35
      if you want the mean.
36
    - Sends gradients to inputs but not the targets.
37
    """
38
    assert input_logits.size() == target_logits.size()
39
    if sigmoid:
40
        input_log_softmax = torch.log(torch.sigmoid(input_logits))
41
        target_softmax = torch.sigmoid(target_logits)
42
    else:
43
        input_log_softmax = F.log_softmax(input_logits, dim=1)
44
        target_softmax = F.softmax(target_logits, dim=1)
45
46
    # return F.kl_div(input_log_softmax, target_softmax)
47
    kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
48
    # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...])
49
    return kl_div
50
51
def softmax_mse_loss(input_logits, target_logits):
52
    """Takes softmax on both sides and returns MSE loss
53
54
    Note:
55
    - Returns the sum over all examples. Divide by the batch size afterwards
56
      if you want the mean.
57
    - Sends gradients to inputs but not the targets.
58
    """
59
    assert input_logits.size() == target_logits.size()
60
    input_softmax = F.softmax(input_logits, dim=1)
61
    target_softmax = F.softmax(target_logits, dim=1)
62
63
    mse_loss = F.mse_loss(input_softmax,target_softmax)
64
    return mse_loss
65
66
def mse_loss(input1, input2):
67
    return torch.mean((input1 - input2)**2)
68
69
class DiceLoss(nn.Module):
70
    def __init__(self, n_classes):
71
        super(DiceLoss, self).__init__()
72
        self.n_classes = n_classes
73
74
    def _one_hot_encoder(self, input_tensor):
75
        tensor_list = []
76
        for i in range(self.n_classes):
77
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
78
            tensor_list.append(temp_prob)
79
        output_tensor = torch.cat(tensor_list, dim=1)
80
        return output_tensor.float()
81
82
    def _dice_loss(self, score, target):
83
        target = target.float()
84
        smooth = 1e-10
85
        intersection = torch.sum(score * target)
86
        union = torch.sum(score * score) + torch.sum(target * target) + smooth
87
        loss = 1 - intersection / union
88
        return loss
89
90
    def forward(self, inputs, target, weight=None, softmax=False):
91
        if softmax:
92
            inputs = torch.softmax(inputs, dim=1)
93
        target = self._one_hot_encoder(target)
94
        if weight is None:
95
            weight = [1] * self.n_classes
96
        assert inputs.size() == target.size(), 'predict & target shape do not match'
97
        class_wise_dice = []
98
        loss = 0.0
99
        for i in range(0, self.n_classes):
100
            dice = self._dice_loss(inputs[:, i], target[:, i])
101
            class_wise_dice.append(1.0 - dice.item())
102
            loss += dice * weight[i]
103
        return loss / self.n_classes