# loss.py
# Source: repository file [9cc651] — 19 lines (14 with data), 520 bytes.
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss computed on the first channel of the inputs.

    The loss is ``1 - DSC`` where

        DSC = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    and ``p`` / ``t`` are channel 0 of ``y_pred`` / ``y_true`` flattened over
    every remaining dimension (batch included), so one scalar is produced for
    the whole batch.

    Args:
        smooth: Additive term applied to both numerator and denominator. With
            ``smooth > 0`` the ratio stays defined when both inputs are
            all-zero (DSC is then 1, so the loss is 0). Defaults to ``1.0``.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """Return the scalar Dice loss between ``y_pred`` and ``y_true``.

        Args:
            y_pred: Prediction tensor; only ``y_pred[:, 0]`` is used.
                Presumably probabilities shaped (N, C, ...) — TODO confirm
                against callers.
            y_true: Target tensor with the same size as ``y_pred``; only
                ``y_true[:, 0]`` is used.

        Returns:
            0-dim tensor equal to ``1 - DSC``.

        Raises:
            ValueError: If ``y_pred`` and ``y_true`` differ in size.
        """
        # Explicit check rather than ``assert`` so validation still runs
        # under ``python -O``.
        if y_pred.size() != y_true.size():
            raise ValueError(
                "y_pred and y_true must have the same size, got "
                f"{tuple(y_pred.size())} and {tuple(y_true.size())}"
            )
        # Keep only channel 0 and flatten the rest; reshape copies only when
        # the slice is non-contiguous, matching contiguous().view(-1).
        y_pred = y_pred[:, 0].reshape(-1)
        y_true = y_true[:, 0].reshape(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2.0 * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1.0 - dsc