Diff of /loss/dice.py [000000] .. [f77492]

Switch to side-by-side view

--- a
+++ b/loss/dice.py
@@ -0,0 +1,33 @@
+import torch
+from .utils import *
+
+
+def dice_loss(input, target):
+    """
+    2d dice loss
+    :param input: predict tensor
+    :param target: target tensor
+    :return: scalar loss value
+    """
+    input = input > 0.5
+    target = target == torch.max(target)
+
+    input = to_float_and_cuda(input)
+    target = to_float_and_cuda(target)
+
+    num = input * target
+    num = torch.sum(num, dim=2)  # 在dim维度上求和 维度减1 如果想要保留原始维度 使用keepdim=True
+    num = torch.sum(num, dim=2)
+
+    den1 = input * input
+    den1 = torch.sum(den1, dim=2)
+    den1 = torch.sum(den1, dim=2)
+
+    den2 = target * target
+    den2 = torch.sum(den2, dim=2)
+    den2 = torch.sum(den2, dim=2)
+
+    dice = 2 * (num / (den1 + den2)) + 1e-6
+    dice_total = 1 - 1 * torch.sum(dice) / dice.size(0)  # divide by batchsize
+
+    return dice_total