--- a +++ b/common/loss.py @@ -0,0 +1,62 @@ +import torch.nn as nn +import torch +import numpy as np + + + + +def _fast_hist(true, pred, num_classes): + pred = np.round(pred).astype(int) + true = np.round(true).astype(int) + mask = (true >= 0) & (true < num_classes) + hist = np.bincount( + num_classes * true[mask] + pred[mask], + minlength=num_classes ** 2, + ).reshape(num_classes, num_classes).astype(np.float) + return hist + +def jaccard_index(hist): + """Computes the Jaccard index, a.k.a the Intersection over Union (IoU). + Args: + hist: confusion matrix. + Returns: + avg_jacc: the average per-class jaccard index. + """ + A_inter_B = np.diag(hist) + A = np.sum(hist,axis=1) + B = np.sum(hist,axis=0) + jaccard = A_inter_B / (A + B - A_inter_B + 1e-6) + avg_jacc =np.nanmean(jaccard) #the mean of jaccard without NaNs + return avg_jacc, jaccard + +def dice_coef_metric(hist): + """Computes the dice coefficient). + Args: + hist: confusion matrix. + Returns: + avg_dice: the average per-class dice coefficient. + """ + A_inter_B = np.diag(hist) + A = np.sum(hist,axis=1) + B = np.sum(hist,axis=0) + dsc = A_inter_B * 2 / (A + B + 1e-6) + avg_dsc=np.nanmean(dsc) #the mean of dsc without NaNs + return avg_dsc + + +def dice_coef_loss(y_pred, y_true): + smooth=1.0 + assert y_pred.size() == y_true.size() + intersection = (y_pred * y_true).sum() + dsc = (2. * intersection + smooth) / ( + y_pred.sum() + y_true.sum() + smooth + ) + return 1. - dsc + + +def bce_dice_loss(y_pred, y_true): + dicescore = dice_coef_loss(y_pred, y_true) + bcescore = nn.BCELoss() + m = nn.Sigmoid() + bceloss = bcescore(m(y_pred), y_true) + return (bceloss + dicescore)