[72db80]: / loss / diceloss.py

Download this file

15 lines (12 with data), 453 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import torch
class diceloss(torch.nn.Module):
    """Soft Dice loss computed over every element of the inputs at once.

    Both inputs are flattened to 1-D and scored as a single set, so the loss
    is aggregated over the whole batch rather than averaged per sample::

        loss = 1 - (2 * sum(p * t) + smooth) / (sum(p ** 2) + sum(t ** 2) + smooth)

    Args:
        smooth: Additive smoothing term applied to numerator and denominator.
            It keeps the ratio defined when both inputs are all zeros (the
            loss is then 0). Defaults to 1.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar soft Dice loss between ``pred`` and ``target``.

        Args:
            pred: Prediction tensor. Presumably probabilities in [0, 1]
                (e.g. a sigmoid output) — TODO confirm with callers.
            target: Tensor with the same number of elements as ``pred``;
                presumably a binary mask — TODO confirm.

        Returns:
            A 0-dim tensor; it is 0 when ``pred`` equals ``target`` exactly.
        """
        smooth = self.smooth
        # reshape copies only when needed, so non-contiguous inputs work
        # without an explicit .contiguous() call.
        iflat = pred.reshape(-1)
        tflat = target.reshape(-1)
        intersection = (iflat * tflat).sum()
        # Squared sums in the denominator (rather than plain sums).
        a_sum = (iflat * iflat).sum()
        b_sum = (tflat * tflat).sum()
        return 1 - (2.0 * intersection + smooth) / (a_sum + b_sum + smooth)