--- a +++ b/bme1312/evaluation.py @@ -0,0 +1,117 @@ + +import torch +from sklearn.metrics import precision_recall_fscore_support,accuracy_score,jaccard_score + + +# SR : Segmentation Result +# GT : Ground Truth + +def get_accuracy(SR, GT, threshold=0.5): + # SR = SR > threshold + # GT = GT == torch.max(GT) + # corr = torch.sum(SR == GT) + # tensor_size = SR.size(0) * SR.size(1) * SR.size(2) * SR.size(3) + # acc = float(corr) / float(tensor_size) + acc=accuracy_score(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold) + + return acc + + +def get_sensitivity(SR, GT, threshold=0.5): + # Sensitivity == Recall + # SR = SR > threshold + # GT = GT == torch.max(GT) + + # # TP : True Positive + # # FN : False Negative + # TP = ((SR == 1) & (GT == 1)) + # FN = ((SR == 0) & (GT == 1)) + + # print("torch.sum(TP)",torch.sum(TP)) + # print("torch.sum(TP + FN)",torch.sum(TP + FN)) + + # SE = float(torch.sum(TP)) / (float(torch.sum(TP + FN)) + 1e-6) + SE=precision_recall_fscore_support(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold)[1][1] + return SE + + +def get_specificity(SR, GT, threshold=0.5): + # SR = SR > threshold + # GT = GT == torch.max(GT) + + # # TN : True Negative + # # FP : False Positive + # TN = ((SR == 0) & (GT == 0)) + # FP = ((SR == 1) & (GT == 0)) + + # SP = float(torch.sum(TN)) / (float(torch.sum(TN + FP)) + 1e-6) + + # print("------------------specificity---------------") + # print("TN:",torch.sum(TN)) + # print("FP:",torch.sum(FP)) + # print("SP:",SP) + SP=precision_recall_fscore_support(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold)[1][0] + return SP + + +def get_precision(SR, GT, threshold=0.5): + # SR = SR > threshold + # GT = GT == torch.max(GT) + + # # TP : True Positive + # # FP : False Positive + # TP = ((SR == 1) & (GT == 1)) + # FP = ((SR == 1) & (GT == 0)) + + # PC = float(torch.sum(TP)) / (float(torch.sum(TP + FP)) + 1e-6) + + # print("------------------precision---------------") + # print("TP:", torch.sum(TP)) + # print("FP:", torch.sum(FP)) + # print("PC:", PC) + PC=precision_recall_fscore_support(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold)[0][1] + return PC + + +def get_F1(SR, GT, threshold=0.5): + # Sensitivity == Recall + # SE = get_sensitivity(SR, GT, threshold=threshold) + # PC = get_precision(SR, GT, threshold=threshold) + + # F1 = 2 * SE * PC / (SE + PC + 1e-6) + F1 = precision_recall_fscore_support(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold)[2][1] + return F1 + + +def get_JS(SR, GT, threshold=0.5): + # JS : Jaccard similarity + # SR = SR > threshold + # GT = GT == torch.max(GT) + + # Inter = torch.sum(((SR == 1) & (GT == 1))) + # Union = torch.sum((SR + GT) >= 1) + + # JS = float(Inter) / (float(Union) + 1e-6) + JS = jaccard_score(GT.cpu().flatten().numpy(),(SR.cpu().flatten().numpy()>threshold)) + return JS + + +# def get_DC(SR, GT, threshold=0.5): +# # DC : Dice Coefficient +# # SR = SR > threshold +# # GT = GT == torch.max(GT) + + +# # Inter = torch.sum(((SR==1) & (GT==1))) +# # DC = float(2 * Inter) / (float(torch.sum(SR) + torch.sum(GT)) + 1e-6) +# DC = precision_recall_fscore_support(GT.cpu().flatten().numpy(),SR.cpu().flatten().numpy()>threshold)[2][1] +# return DC + +def get_DC(SR, GT, threshold=0.5): + # DC : Dice Coefficient + # SR = SR > threshold + # GT = GT == torch.max(GT) + # Inter = torch.sum(((SR==1) & (GT==1))) + # DC = float(2 * Inter) / (float(torch.sum(SR) + torch.sum(GT)) + 1e-6) + DC = precision_recall_fscore_support(GT.cpu().detach().numpy().flatten(), SR.cpu().detach().numpy().flatten() > threshold)[2][1] + return DC \ No newline at end of file