a | b/loss.py | ||
---|---|---|---|
1 | import torch.nn as nn |
||
2 | |||
3 | |||
4 | class DiceLoss(nn.Module): |
||
5 | |||
6 | def __init__(self): |
||
7 | super(DiceLoss, self).__init__() |
||
8 | self.smooth = 1.0 |
||
9 | |||
10 | def forward(self, y_pred, y_true): |
||
11 | assert y_pred.size() == y_true.size() |
||
12 | y_pred = y_pred[:, 0].contiguous().view(-1) |
||
13 | y_true = y_true[:, 0].contiguous().view(-1) |
||
14 | intersection = (y_pred * y_true).sum() |
||
15 | dsc = (2. * intersection + self.smooth) / ( |
||
16 | y_pred.sum() + y_true.sum() + self.smooth |
||
17 | ) |
||
18 | return 1. - dsc |