Diff of /loss.py [000000] .. [9cc651]

Switch to side-by-side view

--- a
+++ b/loss.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+
+
+class DiceLoss(nn.Module):
+
+    def __init__(self):
+        super(DiceLoss, self).__init__()
+        self.smooth = 1.0
+
+    def forward(self, y_pred, y_true):
+        assert y_pred.size() == y_true.size()
+        y_pred = y_pred[:, 0].contiguous().view(-1)
+        y_true = y_true[:, 0].contiguous().view(-1)
+        intersection = (y_pred * y_true).sum()
+        dsc = (2. * intersection + self.smooth) / (
+            y_pred.sum() + y_true.sum() + self.smooth
+        )
+        return 1. - dsc