|
a |
|
b/common/loss.py |
|
|
1 |
import torch.nn as nn |
|
|
2 |
import torch |
|
|
3 |
import numpy as np |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def _fast_hist(true, pred, num_classes): |
|
|
9 |
pred = np.round(pred).astype(int) |
|
|
10 |
true = np.round(true).astype(int) |
|
|
11 |
mask = (true >= 0) & (true < num_classes) |
|
|
12 |
hist = np.bincount( |
|
|
13 |
num_classes * true[mask] + pred[mask], |
|
|
14 |
minlength=num_classes ** 2, |
|
|
15 |
).reshape(num_classes, num_classes).astype(np.float) |
|
|
16 |
return hist |
|
|
17 |
|
|
|
18 |
def jaccard_index(hist): |
|
|
19 |
"""Computes the Jaccard index, a.k.a the Intersection over Union (IoU). |
|
|
20 |
Args: |
|
|
21 |
hist: confusion matrix. |
|
|
22 |
Returns: |
|
|
23 |
avg_jacc: the average per-class jaccard index. |
|
|
24 |
""" |
|
|
25 |
A_inter_B = np.diag(hist) |
|
|
26 |
A = np.sum(hist,axis=1) |
|
|
27 |
B = np.sum(hist,axis=0) |
|
|
28 |
jaccard = A_inter_B / (A + B - A_inter_B + 1e-6) |
|
|
29 |
avg_jacc =np.nanmean(jaccard) #the mean of jaccard without NaNs |
|
|
30 |
return avg_jacc, jaccard |
|
|
31 |
|
|
|
32 |
def dice_coef_metric(hist): |
|
|
33 |
"""Computes the dice coefficient). |
|
|
34 |
Args: |
|
|
35 |
hist: confusion matrix. |
|
|
36 |
Returns: |
|
|
37 |
avg_dice: the average per-class dice coefficient. |
|
|
38 |
""" |
|
|
39 |
A_inter_B = np.diag(hist) |
|
|
40 |
A = np.sum(hist,axis=1) |
|
|
41 |
B = np.sum(hist,axis=0) |
|
|
42 |
dsc = A_inter_B * 2 / (A + B + 1e-6) |
|
|
43 |
avg_dsc=np.nanmean(dsc) #the mean of dsc without NaNs |
|
|
44 |
return avg_dsc |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def dice_coef_loss(y_pred, y_true): |
|
|
48 |
smooth=1.0 |
|
|
49 |
assert y_pred.size() == y_true.size() |
|
|
50 |
intersection = (y_pred * y_true).sum() |
|
|
51 |
dsc = (2. * intersection + smooth) / ( |
|
|
52 |
y_pred.sum() + y_true.sum() + smooth |
|
|
53 |
) |
|
|
54 |
return 1. - dsc |
|
|
55 |
|
|
|
56 |
|
|
|
57 |
def bce_dice_loss(y_pred, y_true): |
|
|
58 |
dicescore = dice_coef_loss(y_pred, y_true) |
|
|
59 |
bcescore = nn.BCELoss() |
|
|
60 |
m = nn.Sigmoid() |
|
|
61 |
bceloss = bcescore(m(y_pred), y_true) |
|
|
62 |
return (bceloss + dicescore) |