# bin/confusion_matrix_rates.r
#
# Functions that print binary-classification metrics: confusion-matrix
# counts, MCC, F1 score, accuracy, TP/TN rates, PR AUC and ROC AUC.

options(stringsAsFactors = FALSE)
# library("clusterSim")

# Install any missing CRAN dependencies, then load them.
list.of.packages <- c("PRROC", "e1071")
new.packages <- list.of.packages[
  !(list.of.packages %in% installed.packages()[, "Package"])
]
if (length(new.packages) > 0) {
  install.packages(new.packages)
}

library("e1071")
library("PRROC")
# NOTE(review): presumably defines dec_two() / signed_dec_two(), which are used
# below but not defined in this file. The path is resolved against the current
# working directory, not this script's location.
source("./utils.r")

# Confusion matrix rates
#
# Prints PR AUC, ROC AUC, the confusion-matrix counts and a one-row summary
# table (MCC, F1 score, accuracy, TP rate, TN rate, PR AUC, ROC AUC).
#
# actual_labels    true labels, 1 = positive, 0 = negative.
# predicted_values scores. The AUCs use them as-is; the confusion matrix uses
#                  them binarized as (score >= threshold).
# keyword          label printed in front of the summary table.
# threshold        decision threshold. When NULL (the default), a numeric
#                  `threshold` visible from the environment this function was
#                  defined in is used (historically a global set by the calling
#                  script). If there is none, 0.5 is used.
#
# Returns, invisibly, a named list of the computed metrics and counts. All
# reporting happens through cat() as a side effect.
confusion_matrix_rates <- function(actual_labels, predicted_values, keyword,
                                   threshold = NULL) {
  stopifnot(length(actual_labels) == length(predicted_values))

  if (is.null(threshold)) {
    # Backward compatibility: this function used to read `threshold` as a free
    # variable through lexical scoping. Search the same enclosing environments.
    threshold <- get0("threshold", envir = parent.env(environment()),
                      mode = "numeric", ifnotfound = 0.5)
  }

  # PRROC takes the foreground (positive) scores as class 0 and the
  # background (negative) scores as class 1.
  fg_test <- predicted_values[actual_labels == 1]
  bg_test <- predicted_values[actual_labels == 0]

  pr_curve_test <- pr.curve(scores.class0 = fg_test, scores.class1 = bg_test,
                            curve = FALSE)
  prc_auc <- pr_curve_test$auc.integral
  cat("\nPR AUC (integral) \t", prc_auc, "\n", sep = "")

  roc_curve_test <- roc.curve(scores.class0 = fg_test, scores.class1 = bg_test,
                              curve = FALSE)
  roc_auc <- roc_curve_test$auc
  cat("ROC AUC \t\t", roc_auc, "\n\n", sep = "")

  # Binarize with a single comparison. Overwriting 1s then 0s in place would
  # turn the freshly written 1s back into 0 whenever threshold > 1.
  # NA scores stay NA.
  predicted <- as.numeric(as.numeric(predicted_values) >= threshold)
  actual <- actual_labels

  TP <- sum(actual == 1 & predicted == 1)
  TN <- sum(actual == 0 & predicted == 0)
  FP <- sum(actual == 0 & predicted == 1)
  FN <- sum(actual == 1 & predicted == 0)

  cat("\nTOTAL:\n\n")
  cat(" FN = ", FN, " / ", FN + TP, "\t (truth == 1) & (prediction < threshold)\n")
  cat(" TP = ", TP, " / ", FN + TP, "\t (truth == 1) & (prediction >= threshold)\n\n")
  cat(" FP = ", FP, " / ", FP + TN, "\t (truth == 0) & (prediction >= threshold)\n")
  cat(" TN = ", TN, " / ", FP + TN, "\t (truth == 0) & (prediction < threshold)\n\n")

  predicted_pos <- TP + FP
  actual_pos <- TP + FN
  actual_neg <- TN + FP
  predicted_neg <- TN + FN

  # sum() of a logical vector is an integer. Promote to double before
  # multiplying so neither product overflows to NA.
  numerator <- as.double(TP) * TN - as.double(FP) * FN
  denom <- as.double(predicted_pos) * actual_pos * actual_neg * predicted_neg
  # An empty margin forces the numerator to 0. Dividing by 1 reports MCC = 0
  # instead of NaN.
  if (any(predicted_pos == 0, actual_pos == 0, actual_neg == 0, predicted_neg == 0)) {
    denom <- 1
  }
  mcc_value <- numerator / sqrt(denom)

  f1_score <- 2 * TP / (2 * TP + FP + FN)
  accuracy <- (TN + TP) / (TN + TP + FP + FN)
  recall <- TP / (TP + FN)
  specificity <- TN / (TN + FP)

  cat("\n\n", keyword, "\t MCC \t F1_score \t accuracy \t TP_rate \t TN_rate \t PR AUC \t ROC AUC\n")
  cat(keyword, "      ", signed_dec_two(mcc_value),
      " \t ", dec_two(f1_score),
      " \t ", dec_two(accuracy),
      " \t ", dec_two(recall),
      " \t ", dec_two(specificity),
      "\t\t ", dec_two(prc_auc),
      "\t\t", dec_two(roc_auc),
      "\n\n")

  invisible(list(
    mcc = mcc_value, f1_score = f1_score, accuracy = accuracy,
    tp_rate = recall, tn_rate = specificity,
    pr_auc = prc_auc, roc_auc = roc_auc,
    TP = TP, TN = TN, FP = FP, FN = FN, threshold = threshold
  ))
}

# Matthews correlation coefficient
#
# Computes the Matthews correlation coefficient (MCC) of binary predictions.
# Prints it formatted with dec_two(), which is defined outside this file, and
# returns the unformatted value.
#
# Jeff Hebert 9/1/2016
# Geoffrey Anderson 10/14/2016
#   Added zero denominator handling.
#   Avoided overflow error on large-ish products in denominator.
# Numerator products are also computed in double precision, for the same
# overflow reason.
#
# actual    vector of true outcomes, 1 = Positive, 0 = Negative
# predicted vector of predicted outcomes, 1 = Positive, 0 = Negative
#
# Returns the MCC, in [-1, 1]. Returns 0 when any confusion-matrix margin
# is empty.
mcc <- function(actual, predicted) {
  stopifnot(length(actual) == length(predicted))

  TP <- sum(actual == 1 & predicted == 1)
  TN <- sum(actual == 0 & predicted == 0)
  FP <- sum(actual == 0 & predicted == 1)
  FN <- sum(actual == 1 & predicted == 0)

  predicted_pos <- TP + FP
  actual_pos <- TP + FN
  actual_neg <- TN + FP
  predicted_neg <- TN + FN

  # sum() of a logical vector is an integer. Promote to double before
  # multiplying: integer products above .Machine$integer.max become NA.
  numerator <- as.double(TP) * TN - as.double(FP) * FN
  denom <- as.double(predicted_pos) * actual_pos * actual_neg * predicted_neg
  # An empty margin forces the numerator to 0. Dividing by 1 returns 0
  # instead of NaN.
  if (any(predicted_pos == 0, actual_pos == 0, actual_neg == 0, predicted_neg == 0)) {
    denom <- 1
  }
  mcc_value <- numerator / sqrt(denom)

  cat("\nMCC = ", dec_two(mcc_value), "\n\n", sep = "")

  mcc_value
}