Diff of /semseg/loss.py [000000] .. [cc8b8f]

Switch to unified view

a b/semseg/loss.py
1
import torch
2
3
4
def dice(outputs, labels):
5
    eps = 1e-5
6
    outputs, labels = outputs.float(), labels.float()
7
    outputs, labels = outputs.flatten(), labels.flatten()
8
    intersect = torch.dot(outputs, labels)
9
    union = torch.add(torch.sum(outputs), torch.sum(labels))
10
    dice_coeff = (2 * intersect + eps) / (union + eps)
11
    dice_loss = - dice_coeff + 1
12
    return dice_loss
13
14
15
def one_hot_encode(label, num_classes):
16
    """ Torch One Hot Encode
17
    :param label: Tensor of shape BxHxW or BxDxHxW
18
    :param num_classes: K classes
19
    :return: label_ohe, Tensor of shape BxKxHxW or BxKxDxHxW
20
    """
21
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
22
    label_ohe = None
23
    if len(label.shape) == 3:
24
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
25
    elif len(label.shape) == 4:
26
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))
27
    for batch_idx, batch_el_label in enumerate(label):
28
        for cls in range(num_classes):
29
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
30
    label_ohe = label_ohe.long()
31
    return label_ohe
32
33
34
def dice_n_classes(outputs, labels, do_one_hot=False, get_list=False, device=None):
35
    """
36
    Computes the Multi-class classification Dice Coefficient.
37
    It is computed as the average Dice for all classes, each time
38
    considering a class versus all the others.
39
    Class 0 (background) is not considered in the average.
40
    :param outputs: probabilities outputs of the CNN. Shape: [BxKxHxW]
41
    :param labels:  ground truth                      Shape: [BxKxHxW]
42
    :param do_one_hot: set to True if ground truth has shape [BxHxW]
43
    :param get_list:   set to True if you want the list of dices per class instead of average
44
    :param device: CUDA device on which compute the dice
45
    :return: Multiclass classification Dice Loss
46
    """
47
    num_classes = outputs.shape[1]
48
    if do_one_hot:
49
        labels = one_hot_encode(labels, num_classes)
50
        labels = labels.cuda(device=device)
51
52
    dices = list()
53
    for cls in range(1, num_classes):
54
        outputs_ = outputs[:, cls].unsqueeze(dim=1)
55
        labels_  = labels[:, cls].unsqueeze(dim=1)
56
        dice_ = dice(outputs_, labels_)
57
        dices.append(dice_)
58
    if get_list:
59
        return dices
60
    else:
61
        return sum(dices) / (num_classes-1)
62
63
64
def get_multi_dice_loss(outputs, labels, device=None):
65
    labels = labels[:, 0]
66
    return dice_n_classes(outputs, labels, do_one_hot=True, get_list=False, device=device)