Diff of /loss.py [000000] .. [9ff54e]

Switch to side-by-side view

--- a
+++ b/loss.py
@@ -0,0 +1,23 @@
+import sys
+from torch import nn
+import torch
+
+
+class DiceLoss(nn.Module):
+    """
+    Dice loss function class
+    """
+    def __init__(self, squared_denom=False):
+        super(DiceLoss, self).__init__()
+        self.smooth = sys.float_info.epsilon
+        self.squared_denom = squared_denom
+
+    def forward(self, x, target):
+        x = x.view(-1)
+        target = target.view(-1)
+        intersection = (x * target).sum()
+        numer = 2. * intersection + self.smooth
+        factor = 2 if self.squared_denom else 1
+        denom = x.pow(factor).sum() + target.pow(factor).sum() + self.smooth
+        dice_index = numer / denom
+        return 1 - dice_index