--- a +++ b/loss.py @@ -0,0 +1,18 @@ +import torch.nn as nn + + +class DiceLoss(nn.Module): + + def __init__(self): + super(DiceLoss, self).__init__() + self.smooth = 1.0 + + def forward(self, y_pred, y_true): + assert y_pred.size() == y_true.size() + y_pred = y_pred[:, 0].contiguous().view(-1) + y_true = y_true[:, 0].contiguous().view(-1) + intersection = (y_pred * y_true).sum() + dsc = (2. * intersection + self.smooth) / ( + y_pred.sum() + y_true.sum() + self.smooth + ) + return 1. - dsc