|
a |
|
b/loss.py |
|
|
1 |
import sys |
|
|
2 |
from torch import nn |
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class DiceLoss(nn.Module):
    """Dice loss: ``1 - Dice coefficient`` between a prediction and a target.

    Both inputs are flattened before the coefficient is computed, so a single
    coefficient is produced over every element of the batch (it is not
    computed per sample and then averaged).

    Args:
        squared_denom: if True, the denominator sums the squared inputs
            (``sum(x**2) + sum(target**2)``) instead of their plain sums.
        smooth: small constant added to both numerator and denominator so the
            ratio stays defined when both inputs are all zeros (the loss is
            then 0). Defaults to ``sys.float_info.epsilon``.
    """

    def __init__(self, squared_denom=False, smooth=sys.float_info.epsilon):
        super().__init__()
        self.smooth = smooth
        self.squared_denom = squared_denom

    def forward(self, x, target):
        """Return the Dice loss as a 0-dim tensor.

        Args:
            x: predictions of any shape; presumably values in [0, 1]
                (e.g. sigmoid outputs) -- TODO confirm with callers.
            target: ground truth, expected to have the same number of
                elements as ``x``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed/permuted outputs; it only copies when it has to.
        x = x.reshape(-1)
        target = target.reshape(-1)
        intersection = (x * target).sum()
        numer = 2. * intersection + self.smooth
        factor = 2 if self.squared_denom else 1
        denom = x.pow(factor).sum() + target.pow(factor).sum() + self.smooth
        return 1 - numer / denom