|
a |
|
b/loss/dice.py |
|
|
1 |
import torch |
|
|
2 |
from .utils import * |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def dice_loss(input, target): |
|
|
6 |
""" |
|
|
7 |
2d dice loss |
|
|
8 |
:param input: predict tensor |
|
|
9 |
:param target: target tensor |
|
|
10 |
:return: scalar loss value |
|
|
11 |
""" |
|
|
12 |
input = input > 0.5 |
|
|
13 |
target = target == torch.max(target) |
|
|
14 |
|
|
|
15 |
input = to_float_and_cuda(input) |
|
|
16 |
target = to_float_and_cuda(target) |
|
|
17 |
|
|
|
18 |
num = input * target |
|
|
19 |
num = torch.sum(num, dim=2) # 在dim维度上求和 维度减1 如果想要保留原始维度 使用keepdim=True |
|
|
20 |
num = torch.sum(num, dim=2) |
|
|
21 |
|
|
|
22 |
den1 = input * input |
|
|
23 |
den1 = torch.sum(den1, dim=2) |
|
|
24 |
den1 = torch.sum(den1, dim=2) |
|
|
25 |
|
|
|
26 |
den2 = target * target |
|
|
27 |
den2 = torch.sum(den2, dim=2) |
|
|
28 |
den2 = torch.sum(den2, dim=2) |
|
|
29 |
|
|
|
30 |
dice = 2 * (num / (den1 + den2)) + 1e-6 |
|
|
31 |
dice_total = 1 - 1 * torch.sum(dice) / dice.size(0) # divide by batchsize |
|
|
32 |
|
|
|
33 |
return dice_total |