|
a |
|
b/loss.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import torch.nn as nn |
|
|
4 |
import torch.nn.functional as F |
|
|
5 |
from torch.autograd import Function, Variable |
|
|
6 |
|
|
|
7 |
def cross_entropy2d(input, target, weight=None, size_average=True, ignore_index=250):
    """Pixel-wise multi-class cross-entropy for dense predictions.

    Args:
        input: raw class scores of shape (N, C, H, W). log-softmax is applied
            over dim 1 here, so pass logits, not probabilities.
        target: integer class labels with N*H*W elements laid out in
            N, H, W order, e.g. (N, 1, H, W) or (N, H, W). Pixels with a
            negative label are dropped entirely. Pixels labelled
            ``ignore_index`` contribute no loss.
        weight: optional per-class rescaling weight forwarded to
            ``F.nll_loss``.
        size_average: if True, divide the summed loss by the number of pixels
            that actually contributed (label >= 0 and != ``ignore_index``).
            Otherwise return the plain sum.
        ignore_index: label value whose pixels are excluded from the loss.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if ``target`` does not hold exactly N*H*W labels.
    """
    n, c, h, w = input.size()
    if target.numel() != n * h * w:
        raise ValueError(
            "target must contain N*H*W = %d labels, got shape %s"
            % (n * h * w, tuple(target.size())))

    # (N, C, H, W) -> (N*H*W, C): one row of class log-probabilities per pixel,
    # in the same N, H, W order as the flattened labels.
    log_p = F.log_softmax(input, dim=1).permute(0, 2, 3, 1).reshape(-1, c)

    # Drop pixels with a negative label; nll_loss rejects negative classes.
    labels = target.reshape(-1)
    keep = labels >= 0
    log_p = log_p[keep]
    labels = labels[keep]

    loss = F.nll_loss(log_p, labels, weight=weight,
                      ignore_index=ignore_index, reduction='sum')
    if size_average:
        # ignore_index pixels add nothing to the sum, so they must not count
        # in the denominator either. clamp avoids 0/0 = nan when every pixel
        # is ignored.
        num_valid = (labels != ignore_index).sum().clamp(min=1)
        loss = loss / num_valid.to(loss.dtype)
    return loss
|
|
34 |
|
|
|
35 |
def loss_ce_t(input, target):
    """Per-class binary cross-entropy against a one-hot encoding of ``target``.

    Args:
        input: raw logits of shape (N, C, H, W). The sigmoid is applied inside
            ``F.binary_cross_entropy_with_logits``.
        target: integer class indices of shape (N, 1, H, W), each in [0, C).

    Returns:
        Scalar mean BCE over all N*C*H*W elements.
    """
    # One-hot encode along the channel axis. Building it from ``input`` makes
    # batch size, class count, dtype and device follow the logits instead of
    # being hard-coded.
    target_bin = torch.zeros_like(input).scatter_(1, target, 1)
    return F.binary_cross_entropy_with_logits(input, target_bin)
|
|
39 |
|
|
|
40 |
def dice_loss(input, target):
    """Soft Dice loss between per-class scores and integer labels.

    Args:
        input: per-class scores of shape (N, C, H, W).
        target: integer class indices of shape (N, 1, H, W), each in [0, C).

    Returns:
        ``1 - (2*sum(x*y) + 1) / (sum(x) + sum(y) + 1)``, computed over the
        whole batch at once (not averaged per example), where ``y`` is the
        one-hot encoding of ``target``.
    """
    # One-hot labels shaped, typed and placed like ``input`` (no fixed class
    # count or forced .cuda()).
    target_bin = torch.zeros_like(input).scatter_(1, target, 1)
    smooth = 1.  # additive smoothing; keeps the ratio defined for empty masks
    iflat = input.reshape(-1)
    tflat = target_bin.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
|
|
48 |
|
|
|
49 |
def weighted_loss(input, target, weight, size_average=True):
    """Pixel-weighted binary cross-entropy between softmax scores and one-hot labels.

    Args:
        input: raw class scores of shape (N, C, H, W). Softmax is applied over
            dim 1 here.
        target: integer class indices of shape (N, 1, H, W), each in [0, C).
        weight: per-pixel weights, a 4-D tensor with N*H*W elements
            (e.g. (N, 1, H, W)). Each pixel's weight is applied to all C
            class entries of that pixel.
        size_average: if True, divide the summed loss by the number of pixels
            N*H*W. Otherwise return the sum.

    Returns:
        Scalar loss tensor.
    """
    n, c, h, w = input.size()
    # One-hot labels built like ``input`` (same device/dtype), then laid out
    # one row per pixel: (N, C, H, W) -> (N*H*W, C).
    target_bin = torch.zeros_like(input).scatter_(1, target, 1)
    target_bin = target_bin.permute(0, 2, 3, 1).reshape(n * h * w, c)

    # NHWC rows of class probabilities, aligned with target_bin.
    probs = F.softmax(input, dim=1).permute(0, 2, 3, 1).reshape(n * h * w, c)

    # One weight per pixel, broadcast across its C class columns.
    pixel_weight = weight.permute(0, 2, 3, 1).reshape(n * h * w, 1).expand(-1, c)

    loss = F.binary_cross_entropy(probs, target_bin, weight=pixel_weight,
                                  reduction='sum')
    if size_average:
        # Every pixel contributes, so average over the N*H*W pixels.
        loss = loss / float(n * h * w)
    return loss
|
|
73 |
|
|
|
74 |
def bce2d_hed(input, target):
    """Class-balanced binary cross-entropy for 2-D maps (HED-style weighting).

    Pixels with label > 0 are positives and pixels with label == 0 are
    negatives. Positives are weighted by the fraction of negatives and vice
    versa, so the rarer class is up-weighted. Any other label (e.g. negative
    values) gets weight 0.

    Args:
        input: per-pixel probabilities in [0, 1], a 4-D (N, C, H, W) tensor.
            ``F.binary_cross_entropy`` is used, so no sigmoid is applied here.
        target: labels with the same layout as ``input``. Values > 0 are
            binarised to 1 before computing the BCE.

    Returns:
        Scalar mean of the weighted element-wise BCE over all elements.
    """
    # Flatten both tensors in the same NHWC order so elements stay aligned.
    prob = input.permute(0, 2, 3, 1).reshape(1, -1)
    target_t = target.permute(0, 2, 3, 1).reshape(1, -1).to(
        device=input.device, dtype=input.dtype)

    pos_mask = target_t > 0
    neg_mask = target_t == 0
    # BCE needs targets in [0, 1]; raw labels may be e.g. 255.
    target_bin = pos_mask.to(input.dtype)

    pos_num = pos_mask.sum().to(input.dtype)
    neg_num = neg_mask.sum().to(input.dtype)
    sum_num = (pos_num + neg_num).clamp(min=1)  # avoid 0/0 if nothing is labelled

    weight = torch.zeros_like(prob)
    weight[pos_mask] = neg_num / sum_num
    weight[neg_mask] = pos_num / sum_num

    return F.binary_cross_entropy(prob, target_bin, weight=weight, reduction='mean')
|
|
101 |
|
|
|
102 |
def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
    """Bootstrapped (hardest-pixel) cross-entropy for segmentation.

    For each image separately, the per-pixel NLL is computed and only the K
    largest values are averaged. The per-image results are then averaged over
    the batch.

    Args:
        input: raw class scores of shape (N, C, H, W).
        target: integer labels with H*W elements per image, e.g. (N, 1, H, W).
            Negative labels are dropped. Label 250 is passed to ``F.nll_loss``
            as ``ignore_index``, so such pixels stay in the pool with a loss
            of 0.
        K: number of hardest pixels kept per image. If an image has fewer
            candidate pixels than K, all of them are used.
        weight: optional per-class weight forwarded to ``F.nll_loss``.
        size_average: accepted for API compatibility; not used.

    Returns:
        Scalar loss tensor.
    """
    batch_size = input.size(0)

    def _bootstrap_xentropy_single(input, target):
        # input: (1, C, H, W); target: the H*W labels of the same image.
        c = input.size(1)
        log_p = F.log_softmax(input, dim=1).permute(0, 2, 3, 1).reshape(-1, c)
        labels = target.reshape(-1)
        keep = labels >= 0
        pixel_loss = F.nll_loss(log_p[keep], labels[keep], weight=weight,
                                ignore_index=250, reduction='none')
        # topk raises if k exceeds the number of candidates, so clamp it.
        k = min(K, pixel_loss.numel())
        if k == 0:
            # No candidate pixels: a zero that stays attached to the graph.
            return pixel_loss.sum() * 0.0
        topk_loss, _ = pixel_loss.topk(k)
        return topk_loss.sum() / k

    loss = 0.0
    # Bootstrap from each image, not from the entire batch.
    for i in range(batch_size):
        loss += _bootstrap_xentropy_single(input[i:i + 1], target[i:i + 1])
    return loss / float(batch_size)
|
|
131 |
|
|
|
132 |
# another implementation for dice loss
|
|
133 |
import torch |
|
|
134 |
from torch.autograd import Function, Variable |
|
|
135 |
class DiceCoeff(Function):
    """Dice coefficient for a single example (legacy instance-method Function).

    ``forward`` returns
    ``2 * (<input, target> + 1e-4) / (sum(input) + sum(target) + 1e-4)``
    as a float32 tensor. ``backward`` gives the analytic gradient with respect
    to ``input``; ``target`` is treated as a constant.

    NOTE(review): this uses the old non-static autograd Function API. Newer
    PyTorch releases reject invoking such Functions through autograd, so
    confirm the installed version before relying on ``backward``.
    """

    def forward(self, input, target):
        self.save_for_backward(input, target)
        # 1e-4 smoothing keeps the ratio defined when both inputs are all zero.
        self.inter = torch.dot(input.reshape(-1), target.reshape(-1)) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001
        return 2 * self.inter.float() / self.union.float()

    # This function has only a single output, so it gets only one gradient.
    def backward(self, grad_output):
        input, target = self.saved_tensors
        grad_input = None
        if self.needs_input_grad[0]:
            # d/dx [2 * I / U] with dI/dx = target and dU/dx = 1 gives
            #     2 * (target * U - I) / U**2
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        # No gradient is propagated to the target.
        return grad_input, None
|
|
153 |
def dice_coeff(input, target):
    """Mean per-example Dice coefficient over a batch.

    Args:
        input: per-class scores of shape (N, C, H, W).
        target: integer class indices of shape (N, 1, H, W), each in [0, C).

    Returns:
        A float32 tensor of shape (1,) on ``input``'s device. It holds the
        average over the N examples of
        ``2 * (<x, y> + 1e-4) / (sum(x) + sum(y) + 1e-4)``, where ``y`` is the
        one-hot encoding of that example's labels. This is the same formula
        as ``DiceCoeff.forward``.

    Raises:
        ValueError: if the batch is empty.
    """
    if input.size(0) == 0:
        raise ValueError("dice_coeff needs at least one example")
    # One-hot labels with input's batch size, class count, dtype and device, so
    # every example in the batch is scored (not just the first).
    target_bin = torch.zeros_like(input).scatter_(1, target, 1)

    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    for x, y in zip(input, target_bin):
        # Computed inline rather than by instantiating an autograd Function,
        # which current PyTorch deprecates. Autograd differentiates these ops.
        inter = torch.dot(x.reshape(-1), y.reshape(-1)) + 0.0001
        union = torch.sum(x) + torch.sum(y) + 0.0001
        total = total + 2 * inter.float() / union.float()
    return total / input.size(0)
|
|
163 |
|