|
a |
|
b/dsc.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
from torch.autograd import Function |
|
|
5 |
|
|
|
6 |
class DiceCoeff(Function):
    """Dice coefficient for a single example (legacy autograd Function).

    NOTE(review): this uses the deprecated instance-style autograd API
    (non-static ``forward``/``backward``); callers in this file invoke
    ``DiceCoeff().forward(...)`` directly, so that interface is kept
    unchanged.
    """

    def forward(self, input, target):
        """Compute the smoothed Dice coefficient of two tensors.

        Args:
            input: prediction tensor (any shape; flattened internally).
            target: ground-truth tensor with the same number of elements.

        Returns:
            0-dim tensor: ``(2 * <input, target> + eps) /
            (input.sum() + target.sum() + eps)``.
        """
        self.save_for_backward(input, target)
        eps = 0.0001  # smoothing term: avoids 0/0 when both masks are empty
        self.inter = torch.dot(input.contiguous().view(-1), target.contiguous().view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        # BUG FIX: `saved_variables` was deprecated and later removed from
        # PyTorch; the supported accessor is `saved_tensors`.
        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # Quotient rule for d/d(input) of 2*inter/union.
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None  # gradient w.r.t. the target is never propagated

        return grad_input, grad_target
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def dice_coeff(input, target):
    """Mean Dice coefficient over a batch.

    Args:
        input: batch of predictions, iterated along dim 0.
        target: batch of ground-truth masks, same leading size as ``input``.

    Returns:
        1-element float tensor holding the per-example average.

    Raises:
        ValueError: if the batch is empty (the original code raised an
            opaque ``NameError`` on ``i`` in that case).
    """
    # Accumulate on the same device as the input.  This replaces the dated
    # `FloatTensor(1).cuda()` / `FloatTensor(1)` branch, which only handled
    # the default CUDA device.
    s = torch.zeros(1, dtype=torch.float32, device=input.device)

    count = 0
    for pred, true in zip(input, target):
        s = s + DiceCoeff().forward(pred, true)
        count += 1

    if count == 0:
        raise ValueError("dice_coeff() received an empty batch")

    return s / count
|
|
45 |
|
|
|
46 |
def DICESEN_loss(input, target):
    """Combined Dice + sensitivity loss: ``2 - dice - sensitivity``.

    Both tensors are flattened.  The Dice term uses squared sums in its
    denominator; the sensitivity term normalises the overlap by the
    squared sum of ``input`` only.
    """
    smooth = 0.00000001

    flat_a = input.view(-1)
    flat_b = target.view(-1)

    overlap = (flat_a * flat_b).sum()
    sq_a = (flat_a * flat_a).sum()
    sq_b = (flat_b * flat_b).sum()

    dice_term = (2. * overlap) / (sq_a + sq_b + smooth)
    sen_term = (1. * overlap) / (sq_a + smooth)

    return 2 - dice_term - sen_term
|
|
54 |
|
|
|
55 |
class DiceSensitivityLoss(nn.Module):
    """Dice + sensitivity loss module: returns ``2 - dice - sensitivity``.

    Raw logits are converted to probabilities with sigmoid (binary case,
    ``n_classes == 1``) or a channel softmax (multi-class case) before the
    loss terms are computed.
    """

    def __init__(self, n_classes):
        # FIX: call nn.Module.__init__ *before* assigning attributes.  The
        # original reversed order only worked because n_classes is a plain
        # int; it would break for Tensor/Parameter/Module attributes.
        super(DiceSensitivityLoss, self).__init__()
        self.n_classes = n_classes  # number of output channels of the model

    def forward(self, inputs, targets, smooth=1.):
        """Compute the loss.

        Args:
            inputs: raw network logits.
            targets: ground-truth masks with the same number of elements.
            smooth: additive smoothing constant (default 1.).

        Returns:
            0-dim tensor with the loss value.
        """
        # Turn logits into probabilities.
        if self.n_classes == 1:
            inputs = torch.sigmoid(inputs)
        else:
            inputs = F.softmax(inputs, dim=1)

        # Renamed from the original's swapped `y_true_f`/`y_pred_f`
        # (which bound *predictions* to `y_true_f`); the math is unchanged.
        probs_flat = inputs.view(-1)
        targets_flat = targets.view(-1)

        intersection = (probs_flat * targets_flat).sum()

        dice = (2. * intersection + smooth) / (targets_flat.sum() + probs_flat.sum() + smooth)

        # Sensitivity term: overlap normalised by the squared sum of the
        # predicted probabilities.
        sen = (1. * intersection) / ((probs_flat * probs_flat).sum() + smooth)

        return 2 - dice - sen