Diff of /loss.py [000000] .. [9ff54e]

Switch to unified view

a b/loss.py
1
import sys
2
from torch import nn
3
import torch
4
5
6
class DiceLoss(nn.Module):
7
    """
8
    Dice loss function class
9
    """
10
    def __init__(self, squared_denom=False):
11
        super(DiceLoss, self).__init__()
12
        self.smooth = sys.float_info.epsilon
13
        self.squared_denom = squared_denom
14
15
    def forward(self, x, target):
16
        x = x.view(-1)
17
        target = target.view(-1)
18
        intersection = (x * target).sum()
19
        numer = 2. * intersection + self.smooth
20
        factor = 2 if self.squared_denom else 1
21
        denom = x.pow(factor).sum() + target.pow(factor).sum() + self.smooth
22
        dice_index = numer / denom
23
        return 1 - dice_index