# utils/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def dice_loss(score, target):
    """Soft Dice loss with squared terms in the denominator.

    `score` is expected to hold probabilities (after softmax/sigmoid) and
    `target` a binary mask of the same shape.
    """
    target = target.float()
    smooth = 1e-5
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - loss
    return loss


def Binary_dice_loss(predictive, target, ep=1e-8):
    """Binary soft Dice loss; `predictive` holds foreground probabilities."""
    intersection = 2 * torch.sum(predictive * target) + ep
    union = torch.sum(predictive) + torch.sum(target) + ep
    loss = 1 - intersection / union
    return loss
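
# Illustrative usage of the Dice helpers above (a sketch, not part of the
# original file); shapes assume a binary 2D segmentation batch:
#
#     probs = torch.softmax(logits, dim=1)                 # (B, 2, H, W)
#     loss = Binary_dice_loss(probs[:, 1], mask.float())   # mask: (B, H, W)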


def kl_loss(inputs, targets, ep=1e-8):
    """KL(targets || inputs) between probability maps.

    `inputs` must already be probabilities; the log is taken here because
    nn.KLDivLoss expects log-probabilities as its first argument.
    """
    kl = nn.KLDivLoss(reduction='mean')
    consist_loss = kl(torch.log(inputs + ep), targets)
    return consist_loss


def soft_ce_loss(inputs, target, ep=1e-8):
    """Soft cross-entropy between two probability maps.

    Hard-coded for exactly two channels (binary segmentation).
    """
    logprobs = torch.log(inputs + ep)
    return torch.mean(-(target[:, 0, ...] * logprobs[:, 0, ...]
                        + target[:, 1, ...] * logprobs[:, 1, ...]))


def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax on both sides and returns the KL divergence.

    Note:
    - With reduction='mean', F.kl_div averages over every element; use
      reduction='batchmean' if you want the mathematically standard
      per-example KL divergence.
    - Gradients flow to both arguments; detach the target logits in the
      caller if they should act as a fixed teacher signal.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_log_softmax = F.logsigmoid(input_logits)  # stable log(sigmoid(x))
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_log_softmax = F.log_softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)

    kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
    return kl_div
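
# Consistency-training sketch (an assumption about intended use, mirroring
# mean-teacher-style setups; `student_logits`, `teacher_model`, and `images`
# are hypothetical names):
#
#     with torch.no_grad():
#         teacher_logits = teacher_model(images)
#     consistency = softmax_kl_loss(student_logits, teacher_logits.detach())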


def softmax_mse_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns the MSE loss.

    Note:
    - F.mse_loss defaults to reduction='mean', so this returns the mean
      over every element, not the sum.
    - Gradients flow to both arguments; detach the target logits in the
      caller if they should act as a fixed teacher signal.
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)

    return F.mse_loss(input_softmax, target_softmax)


def mse_loss(input1, input2):
    """Plain element-wise MSE between two tensors."""
    return torch.mean((input1 - input2) ** 2)


class DiceLoss(nn.Module):
    """Multi-class soft Dice loss averaged over classes."""

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        # Expects an integer class map with a channel dimension of size 1,
        # e.g. (B, 1, H, W), so the per-class masks concatenate along dim=1.
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # broadcast comparison per class
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        # Squared-denominator soft Dice: 1 - (2*intersect + eps) / (|s|^2 + |t|^2 + eps),
        # so a perfect prediction drives the loss to 0.
        target = target.float()
        smooth = 1e-10
        intersection = torch.sum(score * target)
        union = torch.sum(score * score) + torch.sum(target * target)
        loss = 1 - (2 * intersection + smooth) / (union + smooth)
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        class_wise_dice = []
        loss = 0.0
        for i in range(self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())  # per-class Dice score, kept for logging
            loss += dice * weight[i]
        return loss / self.n_classes
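

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original file).
    # Shapes are assumptions: batch of 2, 2 classes, 16x16 slices.
    torch.manual_seed(0)
    logits = torch.randn(2, 2, 16, 16)
    teacher_logits = torch.randn(2, 2, 16, 16)
    labels = torch.randint(0, 2, (2, 1, 16, 16))  # integer class map

    criterion = DiceLoss(n_classes=2)
    print("DiceLoss:", criterion(logits, labels, softmax=True).item())
    print("softmax_mse_loss:", softmax_mse_loss(logits, teacher_logits).item())
    print("softmax_kl_loss:", softmax_kl_loss(logits, teacher_logits).item())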