--- a
+++ b/common/loss.py
@@ -0,0 +1,62 @@
+import torch.nn as nn
+import torch
+import numpy as np
+
+
+
+
+def _fast_hist(true, pred, num_classes):
+    """Confusion matrix (rows = true, cols = pred), skipping out-of-range true labels."""
+    pred = np.round(pred).astype(int)
+    true = np.round(true).astype(int)
+    mask = (true >= 0) & (true < num_classes)  # NOTE(review): pred assumed in [0, num_classes)
+    flat = num_classes * true[mask] + pred[mask]  # row-major cell index (true, pred)
+    hist = np.bincount(flat, minlength=num_classes ** 2)
+    # float64 explicitly: the np.float alias was removed in NumPy 1.24.
+    return hist.reshape(num_classes, num_classes).astype(np.float64)
+
+def jaccard_index(hist):
+    """Compute the Jaccard index (intersection over union) per class and on average.
+
+    Args:
+        hist: square confusion matrix, e.g. from _fast_hist (rows = true, cols = pred).
+    Returns:
+        (avg_jacc, jaccard): the mean IoU and the per-class IoU array. The 1e-6 term
+        keeps every denominator positive, so a class absent from both scores 0, not NaN.
+    """
+    tp = np.diag(hist)
+    union = np.sum(hist, axis=1) + np.sum(hist, axis=0) - tp
+    per_class = tp / (union + 1e-6)
+    return np.nanmean(per_class), per_class
+
+def dice_coef_metric(hist):
+    """Compute the mean per-class Dice coefficient from a confusion matrix.
+
+    Args:
+        hist: square confusion matrix, e.g. from _fast_hist (rows = true, cols = pred).
+    Returns:
+        avg_dsc: mean over classes of 2*TP / (row_sum + col_sum + 1e-6); a class
+        absent from both true and pred therefore scores 0 rather than NaN.
+    """
+    tp = np.diag(hist)
+    denom = np.sum(hist, axis=1) + np.sum(hist, axis=0) + 1e-6
+    per_class = tp * 2 / denom
+    return np.nanmean(per_class)
+
+
+def dice_coef_loss(y_pred, y_true, smooth=1.0):
+    """Soft Dice loss 1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth); same-size tensors."""
+    if y_pred.size() != y_true.size():  # explicit check: assert is stripped under -O
+        raise ValueError(f"size mismatch: {y_pred.size()} vs {y_true.size()}")
+    intersection = (y_pred * y_true).sum()
+    # smooth avoids 0/0 when both inputs sum to zero (the loss is then 0)
+    dsc = (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
+    return 1. - dsc
+
+
+def bce_dice_loss(y_pred, y_true):
+    """Sum of BCE and soft Dice loss, treating y_pred as raw logits (pre-sigmoid)."""
+    probs = torch.sigmoid(y_pred)  # Dice must see probabilities, not unbounded logits
+    dicescore = dice_coef_loss(probs, y_true)
+    bceloss = nn.BCEWithLogitsLoss()(y_pred, y_true)  # fused sigmoid+BCE, numerically stable
+    return bceloss + dicescore