|
a |
|
b/semseg/loss.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
def dice(outputs, labels): |
|
|
5 |
eps = 1e-5 |
|
|
6 |
outputs, labels = outputs.float(), labels.float() |
|
|
7 |
outputs, labels = outputs.flatten(), labels.flatten() |
|
|
8 |
intersect = torch.dot(outputs, labels) |
|
|
9 |
union = torch.add(torch.sum(outputs), torch.sum(labels)) |
|
|
10 |
dice_coeff = (2 * intersect + eps) / (union + eps) |
|
|
11 |
dice_loss = - dice_coeff + 1 |
|
|
12 |
return dice_loss |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def one_hot_encode(label, num_classes): |
|
|
16 |
""" Torch One Hot Encode |
|
|
17 |
:param label: Tensor of shape BxHxW or BxDxHxW |
|
|
18 |
:param num_classes: K classes |
|
|
19 |
:return: label_ohe, Tensor of shape BxKxHxW or BxKxDxHxW |
|
|
20 |
""" |
|
|
21 |
assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape) |
|
|
22 |
label_ohe = None |
|
|
23 |
if len(label.shape) == 3: |
|
|
24 |
label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2])) |
|
|
25 |
elif len(label.shape) == 4: |
|
|
26 |
label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3])) |
|
|
27 |
for batch_idx, batch_el_label in enumerate(label): |
|
|
28 |
for cls in range(num_classes): |
|
|
29 |
label_ohe[batch_idx, cls] = (batch_el_label == cls) |
|
|
30 |
label_ohe = label_ohe.long() |
|
|
31 |
return label_ohe |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def dice_n_classes(outputs, labels, do_one_hot=False, get_list=False, device=None): |
|
|
35 |
""" |
|
|
36 |
Computes the Multi-class classification Dice Coefficient. |
|
|
37 |
It is computed as the average Dice for all classes, each time |
|
|
38 |
considering a class versus all the others. |
|
|
39 |
Class 0 (background) is not considered in the average. |
|
|
40 |
:param outputs: probabilities outputs of the CNN. Shape: [BxKxHxW] |
|
|
41 |
:param labels: ground truth Shape: [BxKxHxW] |
|
|
42 |
:param do_one_hot: set to True if ground truth has shape [BxHxW] |
|
|
43 |
:param get_list: set to True if you want the list of dices per class instead of average |
|
|
44 |
:param device: CUDA device on which compute the dice |
|
|
45 |
:return: Multiclass classification Dice Loss |
|
|
46 |
""" |
|
|
47 |
num_classes = outputs.shape[1] |
|
|
48 |
if do_one_hot: |
|
|
49 |
labels = one_hot_encode(labels, num_classes) |
|
|
50 |
labels = labels.cuda(device=device) |
|
|
51 |
|
|
|
52 |
dices = list() |
|
|
53 |
for cls in range(1, num_classes): |
|
|
54 |
outputs_ = outputs[:, cls].unsqueeze(dim=1) |
|
|
55 |
labels_ = labels[:, cls].unsqueeze(dim=1) |
|
|
56 |
dice_ = dice(outputs_, labels_) |
|
|
57 |
dices.append(dice_) |
|
|
58 |
if get_list: |
|
|
59 |
return dices |
|
|
60 |
else: |
|
|
61 |
return sum(dices) / (num_classes-1) |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
def get_multi_dice_loss(outputs, labels, device=None): |
|
|
65 |
labels = labels[:, 0] |
|
|
66 |
return dice_n_classes(outputs, labels, do_one_hot=True, get_list=False, device=device) |