Diff of /loss.py [000000] .. [8ff467]


b/loss.py
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable

def cross_entropy2d(input, target, weight=None, size_average=True):

    n, c, h, w = input.size()
    nt, ct, ht, wt = target.size()
    '''
    # Handle inconsistent size between input and target
    if h > ht and w > wt: # upsample labels
        target = target.unsqueeze(1)
        target = F.upsample(target, size=(h, w), mode='nearest')
        target = target.squeeze(1)
    elif h < ht and w < wt: # upsample images
        input = F.upsample(input, size=(ht, wt), mode='bilinear')
    elif h != ht and w != wt:
        raise Exception("Only support upsampling")
    '''
    # Flatten the logits to (N*H*W, C) and keep only pixels whose label is
    # non-negative.
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    log_p = log_p[target.contiguous().view(-1, 1).repeat(1, c) >= 0]
    log_p = log_p.view(-1, c)

    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, ignore_index=250,
                      weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum().float()
    return loss
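
# Illustrative usage sketch (not part of the original file): shows the shapes
# cross_entropy2d expects -- logits of shape (N, C, H, W) and an integer label
# map of shape (N, 1, H, W). The class count 11 below is only an assumption.
def _example_cross_entropy2d():
    logits = torch.randn(2, 11, 32, 32)              # raw network scores
    labels = torch.randint(0, 11, (2, 1, 32, 32))    # per-pixel class indices
    return cross_entropy2d(logits, labels)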

def loss_ce_t(input, target):
    # Binary cross-entropy on logits against a one-hot target (batch size 1 and
    # 11 classes hard-coded). The sigmoid stays commented out because
    # F.binary_cross_entropy_with_logits applies it internally.
    # input=F.sigmoid(input)
    target_bin = Variable(torch.zeros(1, 11, target.shape[2], target.shape[3]).cuda().scatter_(1, target, 1))
    return F.binary_cross_entropy_with_logits(input, target_bin)

def dice_loss(input, target):
    # Soft dice loss: 1 - (2*intersection + smooth) / (sum(input) + sum(target) + smooth).
    # 11 classes are hard-coded; target must be a CUDA LongTensor of shape
    # (N, 1, H, W), and input is expected to hold per-class probabilities.
    target_bin = Variable(torch.zeros(target.shape[0], 11, target.shape[2], target.shape[3]).cuda().scatter_(1, target, 1))
    smooth = 1.
    iflat = input.view(-1)
    tflat = target_bin.view(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
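
# Illustrative usage sketch (not part of the original file): dice_loss
# hard-codes 11 classes and calls .cuda(), so a GPU is assumed; the softmax
# here is only an assumption about how the probabilities are produced.
def _example_dice_loss():
    probs = F.softmax(torch.randn(2, 11, 32, 32), dim=1).cuda()   # (N, C, H, W)
    labels = torch.randint(0, 11, (2, 1, 32, 32)).cuda()          # (N, 1, H, W)
    return dice_loss(probs, labels)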

def weighted_loss(input, target, weight, size_average=True):
    # Pixel-weighted binary cross-entropy between softmax probabilities and a
    # one-hot target; `weight` is a per-pixel weight map broadcast over classes.
    n, c, h, w = input.size()
    target_bin = Variable(torch.zeros(n, c, h, w).cuda()).scatter_(1, target, 1)
    target_bin = target_bin.transpose(1, 2).transpose(2, 3).contiguous().view(n * h * w, c).float()

    # NCHW -> NHWC, then flatten to (N*H*W, C)
    input = F.softmax(input, dim=1).transpose(1, 2).transpose(2, 3).contiguous().view(n * h * w, c)
    input = input[target_bin >= 0]
    input = input.view(-1, c)
    weight = weight.transpose(1, 2).transpose(2, 3).contiguous()
    weight = weight.view(n * h * w, 1).repeat(1, c)
    '''
    mask = target >= 0
    target = target[mask]
    target_bin = np.zeros((n*h*w, c), np.float)
    for i, term in enumerate(target):
        target_bin[i, int(term)] = 1
    target_bin = torch.from_numpy(target_bin).float()
    target_bin = Variable(target_bin.cuda())
    '''
    loss = F.binary_cross_entropy(input, target_bin, weight=weight, size_average=False)
    if size_average:
        loss /= (target_bin >= 0).data.sum().float() / c
    return loss
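
# Illustrative usage sketch (not part of the original file): weighted_loss takes
# raw logits (softmax is applied inside), a (N, 1, H, W) CUDA LongTensor of
# labels, and a (N, 1, H, W) per-pixel weight map; all sizes here are assumptions.
def _example_weighted_loss():
    logits = torch.randn(2, 11, 32, 32).cuda()
    labels = torch.randint(0, 11, (2, 1, 32, 32)).cuda()
    pix_w = torch.ones(2, 1, 32, 32).cuda()   # uniform weights as a placeholder
    return weighted_loss(logits, labels, pix_w)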

def bce2d_hed(input, target):
    """
    Binary Cross Entropy 2-Dimension loss with HED-style class balancing:
    positive (edge) pixels are weighted by the fraction of negatives and
    negative pixels by the fraction of positives. `input` is expected to
    already be probabilities in [0, 1] (e.g. after a sigmoid), since
    F.binary_cross_entropy is used.
    """
    n, c, h, w = input.size()
    # assert(max(target) == 1)
    log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)
    target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1).float().cuda()
    target_trans = target_t.clone()
    pos_index = (target_t > 0)
    neg_index = (target_t == 0)
    target_trans[pos_index] = 1
    target_trans[neg_index] = 0
    pos_index = pos_index.data.cpu().numpy().astype(bool)
    neg_index = neg_index.data.cpu().numpy().astype(bool)
    weight = torch.Tensor(log_p.size()).fill_(0)
    weight = weight.numpy()
    pos_num = pos_index.sum()
    neg_num = neg_index.sum()
    sum_num = pos_num + neg_num
    weight[pos_index] = neg_num * 1.0 / sum_num
    weight[neg_index] = pos_num * 1.0 / sum_num

    weight = torch.from_numpy(weight)
    weight = weight.cuda()
    loss = F.binary_cross_entropy(log_p, target_t, weight, size_average=True)
    return loss
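
# Illustrative usage sketch (not part of the original file): a single-channel
# edge probability map against a binary edge mask. The sigmoid reflects the
# assumption above that bce2d_hed receives probabilities; a CUDA device is
# required because the function moves the target and weights to the GPU.
def _example_bce2d_hed():
    edge_prob = torch.sigmoid(torch.randn(1, 1, 32, 32)).cuda()   # (N, 1, H, W)
    edge_mask = (torch.rand(1, 1, 32, 32) > 0.9).float().cuda()   # sparse edges
    return bce2d_hed(edge_prob, edge_mask)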

def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
    # Bootstrapped cross-entropy: for each image, keep only the K pixels with
    # the highest per-pixel loss (online hard example mining) and average them.

    batch_size = input.size()[0]

    def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
        n, c, h, w = input.size()
        log_p = F.log_softmax(input, dim=1)
        log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
        log_p = log_p.view(-1, c)

        mask = target >= 0
        target = target[mask]
        loss = F.nll_loss(log_p, target, weight=weight, ignore_index=250,
                          reduce=False, size_average=False)
        topk_loss, _ = loss.topk(K)
        reduced_topk_loss = topk_loss.sum() / K

        return reduced_topk_loss

    loss = 0.0
    # Bootstrap from each image, not the entire batch
    for i in range(batch_size):
        loss += _bootstrap_xentropy_single(input=torch.unsqueeze(input[i], 0),
                                           target=torch.unsqueeze(target[i], 0),
                                           K=K,
                                           weight=weight,
                                           size_average=size_average)
    return loss / float(batch_size)
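
# Illustrative usage sketch (not part of the original file): logits of shape
# (N, C, H, W) and integer labels of shape (N, H, W); K is the number of
# hardest pixels kept per image and must not exceed the number of valid
# (label >= 0) pixels. The values below are only assumptions.
def _example_bootstrapped_ce():
    logits = torch.randn(2, 11, 32, 32)
    labels = torch.randint(0, 11, (2, 32, 32))
    return bootstrapped_cross_entropy2d(logits, labels, K=256)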

# another implementation of the dice loss
class DiceCoeff(Function):
    """Dice coeff for individual examples"""
    def forward(self, input, target):
        self.save_for_backward(input, target)
        self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001
        t = 2 * self.inter.float() / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        input, target = self.saved_variables
        grad_input = grad_target = None
        if self.needs_input_grad[0]:
            # Derivative of 2*inter/union w.r.t. input: the numerator is
            # (target*union - inter) and the term is divided by union squared.
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None
        return grad_input, grad_target

def dice_coeff(input, target):
    """Dice coeff for batches"""
    # Batch size 1 and 11 classes are hard-coded; target must be a CUDA
    # LongTensor of shape (1, 1, H, W), which is one-hot encoded via scatter_.
    target_bin = Variable(torch.zeros(1, 11, target.shape[2], target.shape[3]).cuda().scatter_(1, target, 1).float())
    if input.is_cuda:
        s = torch.FloatTensor(1).cuda().zero_()
    else:
        s = torch.FloatTensor(1).zero_()
    # Average the per-sample dice coefficients over the batch.
    for i, c in enumerate(zip(input, target_bin)):
        s = s + DiceCoeff().forward(c[0], c[1])
    return s / (i + 1)
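
# Illustrative usage sketch (not part of the original file): dice_coeff
# hard-codes a batch of 1 and 11 classes and calls .cuda(), so a GPU and those
# shapes are assumed; `probs` stands in for the network's per-class probabilities.
def _example_dice_coeff():
    probs = F.softmax(torch.randn(1, 11, 32, 32), dim=1).cuda()
    labels = torch.randint(0, 11, (1, 1, 32, 32)).cuda()
    return dice_coeff(probs, labels)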