--- /dev/null
+++ b/bme1312/evaluation.py
@@ -0,0 +1,90 @@
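+"""Evaluation metrics for binary segmentation.
+
+SR (segmentation result) and GT (ground truth) are torch tensors; every
+metric flattens them, binarizes SR at `threshold`, and compares the result
+against GT using scikit-learn.
+"""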
+
+import torch
+from sklearn.metrics import precision_recall_fscore_support, accuracy_score, jaccard_score
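+
+# precision_recall_fscore_support returns (precision, recall, fbeta, support),
+# one per-class array per metric, indexed as [metric array][class]; passing
+# labels=[0, 1] pins the negative class to index 0 and the positive class to 1.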
+
+
+# SR : Segmentation Result (predicted mask tensor)
+# GT : Ground Truth (binary 0/1 mask tensor)
+
+def get_accuracy(SR, GT, threshold=0.5):
+    # Accuracy = (TP + TN) / (TP + TN + FP + FN)
+    acc = accuracy_score(GT.cpu().detach().numpy().flatten(),
+                         SR.cpu().detach().numpy().flatten() > threshold)
+    return acc
+
+
+def get_sensitivity(SR, GT, threshold=0.5):
+    # Sensitivity == Recall = TP / (TP + FN): recall of the positive class.
+    SE = precision_recall_fscore_support(
+        GT.cpu().detach().numpy().flatten(),
+        SR.cpu().detach().numpy().flatten() > threshold,
+        labels=[0, 1], zero_division=0)[1][1]
+    return SE
+
+
+def get_specificity(SR, GT, threshold=0.5):
+    # Specificity = TN / (TN + FP): recall of the negative class.
+    SP = precision_recall_fscore_support(
+        GT.cpu().detach().numpy().flatten(),
+        SR.cpu().detach().numpy().flatten() > threshold,
+        labels=[0, 1], zero_division=0)[1][0]
+    return SP
+
+
+def get_precision(SR, GT, threshold=0.5):
+    # Precision = TP / (TP + FP): precision of the positive class.
+    PC = precision_recall_fscore_support(
+        GT.cpu().detach().numpy().flatten(),
+        SR.cpu().detach().numpy().flatten() > threshold,
+        labels=[0, 1], zero_division=0)[0][1]
+    return PC
+
+
+def get_F1(SR, GT, threshold=0.5):
+    # F1 = 2 * PC * SE / (PC + SE): harmonic mean of precision and recall.
+    F1 = precision_recall_fscore_support(
+        GT.cpu().detach().numpy().flatten(),
+        SR.cpu().detach().numpy().flatten() > threshold,
+        labels=[0, 1], zero_division=0)[2][1]
+    return F1
+
+
+def get_JS(SR, GT, threshold=0.5):
+    # JS : Jaccard similarity = |SR ∩ GT| / |SR ∪ GT| (IoU of the positive class).
+    JS = jaccard_score(GT.cpu().detach().numpy().flatten(),
+                       SR.cpu().detach().numpy().flatten() > threshold)
+    return JS
+
+
+def get_DC(SR, GT, threshold=0.5):
+    # DC : Dice Coefficient = 2|SR ∩ GT| / (|SR| + |GT|); for binary masks
+    # this equals the F1 score of the positive class.
+    DC = precision_recall_fscore_support(
+        GT.cpu().detach().numpy().flatten(),
+        SR.cpu().detach().numpy().flatten() > threshold,
+        labels=[0, 1], zero_division=0)[2][1]
+    return DC
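+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): random tensors stand in for
+    # a real (N, C, H, W) prediction / ground-truth pair.
+    SR = torch.rand(2, 1, 8, 8)                   # hypothetical network output in [0, 1]
+    GT = (torch.rand(2, 1, 8, 8) > 0.5).float()   # hypothetical binary mask
+    print("accuracy   :", get_accuracy(SR, GT))
+    print("sensitivity:", get_sensitivity(SR, GT))
+    print("specificity:", get_specificity(SR, GT))
+    print("precision  :", get_precision(SR, GT))
+    print("F1         :", get_F1(SR, GT))
+    print("Jaccard    :", get_JS(SR, GT))
+    print("Dice       :", get_DC(SR, GT))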