Diff of /loss/dice.py [000000] .. [f77492]

Switch to unified view

a b/loss/dice.py
1
import torch
2
from .utils import *
3
4
5
def dice_loss(input, target):
6
    """
7
    2d dice loss
8
    :param input: predict tensor
9
    :param target: target tensor
10
    :return: scalar loss value
11
    """
12
    input = input > 0.5
13
    target = target == torch.max(target)
14
15
    input = to_float_and_cuda(input)
16
    target = to_float_and_cuda(target)
17
18
    num = input * target
19
    num = torch.sum(num, dim=2)  # 在dim维度上求和 维度减1 如果想要保留原始维度 使用keepdim=True
20
    num = torch.sum(num, dim=2)
21
22
    den1 = input * input
23
    den1 = torch.sum(den1, dim=2)
24
    den1 = torch.sum(den1, dim=2)
25
26
    den2 = target * target
27
    den2 = torch.sum(den2, dim=2)
28
    den2 = torch.sum(den2, dim=2)
29
30
    dice = 2 * (num / (den1 + den2)) + 1e-6
31
    dice_total = 1 - 1 * torch.sum(dice) / dice.size(0)  # divide by batchsize
32
33
    return dice_total