Diff of /loss.py [000000] .. [903821]

--- a
+++ b/loss.py
@@ -0,0 +1,58 @@
+import torch.nn.functional as F
+import torch.nn as nn
+import torch
+from einops import rearrange
+
+
+def cal_dice(output, target, eps=1e-3):
+    # hard Dice metric for evaluation: argmax turns logits into predicted
+    # labels, so this assumes binary segmentation with foreground labelled 1
+    output = torch.argmax(output, dim=1)
+    inter = torch.sum(output * target) + eps
+    union = torch.sum(output) + torch.sum(target) + eps * 2
+    dice = 2 * inter / union
+    return dice
+
+
+class Loss(nn.Module):
+    """dice_loss_plus_cetr_weighted: weighted sum of soft Dice loss and cross-entropy."""
+
+    def __init__(self, n_classes, alpha=0.5):
+        super(Loss, self).__init__()
+        self.n_classes = n_classes
+        self.alpha = alpha
+
+    def forward(self, input, target):
+        smooth = 0.01
+
+        input1 = F.softmax(input, dim=1)
+        target1 = F.one_hot(target, self.n_classes)
+        input1 = rearrange(input1, 'b n h w s -> b n (h w s)')
+        target1 = rearrange(target1, 'b h w s n -> b n (h w s)')
+        # keep only the foreground classes (drop channel 0, the background)
+        input1 = input1[:, 1:, :]
+        target1 = target1[:, 1:, :].float()
+
+        # compute the Dice loss over the whole batch; smoothing in both
+        # numerator and denominator keeps the ratio finite on empty masks
+        inter = torch.sum(input1 * target1)
+        union = torch.sum(input1) + torch.sum(target1) + smooth
+        dice = (2.0 * inter + smooth) / union
+
+        # cross-entropy computed on the raw logits and integer labels
+        loss = F.cross_entropy(input, target)
+
+        # alpha weights the Dice term, (1 - alpha) the cross-entropy term
+        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha
+
+        return total_loss
+
+
+if __name__ == '__main__':
+    torch.manual_seed(3)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    losser = Loss(n_classes=2).to(device)
+    # toy 3D volumes: logits (B, n_classes, H, W, D), integer labels (B, H, W, D)
+    x = torch.randn((4, 2, 16, 16, 16)).to(device)
+    y = torch.randint(0, 2, (4, 16, 16, 16)).to(device)
+    print(losser(x, y))
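
For context, the module drops into an ordinary PyTorch training step. A minimal sketch under assumed names (model, loader, and optimizer are hypothetical stand-ins, not part of this diff):

    losser = Loss(n_classes=2).to(device)
    for images, labels in loader:
        images = images.to(device)           # (B, C, H, W, D) float volumes
        labels = labels.to(device)           # (B, H, W, D) integer class labels
        logits = model(images)               # (B, n_classes, H, W, D) raw scores
        loss = losser(logits, labels)        # weighted Dice + cross-entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        dice = cal_dice(logits, labels)      # hard Dice metric, for monitoring only

Note that cal_dice calls torch.argmax, so it is non-differentiable and suited only to evaluation; the soft Dice inside Loss.forward is what carries gradients.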