[9ff54e]: / loss.py

Download this file

24 lines (20 with data), 664 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import sys
from torch import nn
import torch
class DiceLoss(nn.Module):
    """Dice loss: ``1 - Dice coefficient`` between a prediction and a target.

    Both inputs are flattened to 1-D, so the coefficient is computed over all
    elements at once (a batch is treated as one set, not averaged per sample)::

        dice = (2 * sum(x * t) + smooth) / (sum(x ** p) + sum(t ** p) + smooth)
        loss = 1 - dice

    where ``p`` is 2 when ``squared_denom`` is true and 1 otherwise.

    Args:
        squared_denom: If true, square the terms in the denominator
            (``sum(x**2) + sum(t**2)``) instead of summing them directly.
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both inputs are all zeros. Defaults to
            ``sys.float_info.epsilon`` (the previously hard-coded value).
    """

    def __init__(self, squared_denom=False, *, smooth=None):
        super().__init__()
        self.smooth = sys.float_info.epsilon if smooth is None else float(smooth)
        self.squared_denom = squared_denom

    def forward(self, x, target):
        """Return the Dice loss as a 0-dim tensor.

        Args:
            x: Prediction tensor of any shape. Presumably probabilities in
                [0, 1] -- TODO confirm with callers.
            target: Tensor with the same number of elements as ``x``.
                Non-floating masks (bool / integer) are cast to ``x.dtype``.

        Raises:
            ValueError: If ``x`` and ``target`` have different element counts.
        """
        # reshape (not view) so non-contiguous inputs, e.g. after permute()
        # or .t(), are flattened instead of raising a RuntimeError.
        x = x.reshape(-1)
        target = target.reshape(-1)
        # Without this check a 1-element target would silently broadcast
        # against x and produce a meaningless loss.
        if x.numel() != target.numel():
            raise ValueError(
                f"x and target must have the same number of elements, "
                f"got {x.numel()} and {target.numel()}"
            )
        # bool/integer masks: cast so pow()/sum() run in x's floating dtype.
        if not target.is_floating_point():
            target = target.to(x.dtype)

        intersection = (x * target).sum()
        numer = 2.0 * intersection + self.smooth
        power = 2 if self.squared_denom else 1
        denom = x.pow(power).sum() + target.pow(power).sum() + self.smooth
        dice_index = numer / denom
        return 1 - dice_index