Diff of /common/loss.py [000000] .. [f804b3]

Switch to unified view

a b/common/loss.py
1
import torch.nn as nn
2
import torch
3
import numpy as np
4
5
6
7
8
def _fast_hist(true, pred, num_classes):
9
    pred = np.round(pred).astype(int)
10
    true = np.round(true).astype(int)
11
    mask = (true >= 0) & (true < num_classes)
12
    hist = np.bincount(
13
        num_classes * true[mask] + pred[mask],
14
        minlength=num_classes ** 2,
15
    ).reshape(num_classes, num_classes).astype(np.float)
16
    return hist
17
18
def jaccard_index(hist):
19
    """Computes the Jaccard index, a.k.a the Intersection over Union (IoU).
20
    Args:
21
        hist: confusion matrix.
22
    Returns:
23
        avg_jacc: the average per-class jaccard index.
24
    """
25
    A_inter_B = np.diag(hist)
26
    A = np.sum(hist,axis=1)
27
    B = np.sum(hist,axis=0)
28
    jaccard = A_inter_B / (A + B - A_inter_B + 1e-6)
29
    avg_jacc =np.nanmean(jaccard) #the mean of jaccard without NaNs
30
    return avg_jacc, jaccard
31
32
def dice_coef_metric(hist):
33
    """Computes the dice coefficient).
34
    Args:
35
        hist: confusion matrix.
36
     Returns:
37
        avg_dice: the average per-class dice coefficient.
38
    """
39
    A_inter_B = np.diag(hist)
40
    A = np.sum(hist,axis=1)
41
    B = np.sum(hist,axis=0)
42
    dsc = A_inter_B * 2 / (A + B + 1e-6)
43
    avg_dsc=np.nanmean(dsc) #the mean of dsc without NaNs
44
    return avg_dsc
45
46
47
def dice_coef_loss(y_pred, y_true):
48
      smooth=1.0
49
      assert y_pred.size() == y_true.size()
50
      intersection = (y_pred * y_true).sum()
51
      dsc = (2. * intersection + smooth) / (
52
          y_pred.sum() + y_true.sum() + smooth
53
      )
54
      return 1. - dsc
55
56
57
def bce_dice_loss(y_pred, y_true):
58
    dicescore = dice_coef_loss(y_pred, y_true)
59
    bcescore = nn.BCELoss()
60
    m = nn.Sigmoid()
61
    bceloss = bcescore(m(y_pred), y_true)
62
    return (bceloss + dicescore)