Diff of /loss.py [000000] .. [9cc651]

Switch to unified view

a b/loss.py
1
import torch.nn as nn
2
3
4
class DiceLoss(nn.Module):
5
6
    def __init__(self):
7
        super(DiceLoss, self).__init__()
8
        self.smooth = 1.0
9
10
    def forward(self, y_pred, y_true):
11
        assert y_pred.size() == y_true.size()
12
        y_pred = y_pred[:, 0].contiguous().view(-1)
13
        y_true = y_true[:, 0].contiguous().view(-1)
14
        intersection = (y_pred * y_true).sum()
15
        dsc = (2. * intersection + self.smooth) / (
16
            y_pred.sum() + y_true.sum() + self.smooth
17
        )
18
        return 1. - dsc