|
a |
|
b/summary.py |
|
|
1 |
''' |
|
|
2 |
https://github.com/akaraspt/deepsleepnet |
|
|
3 |
Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. |
|
|
4 |
''' |
|
|
5 |
|
|
|
6 |
#! /usr/bin/python |
|
|
7 |
# -*- coding: utf8 -*- |
|
|
8 |
|
|
|
9 |
import argparse |
|
|
10 |
import os |
|
|
11 |
import re |
|
|
12 |
import scipy.io as sio |
|
|
13 |
import numpy as np |
|
|
14 |
from sklearn.metrics import cohen_kappa_score |
|
|
15 |
from sklearn.metrics import confusion_matrix, f1_score |
|
|
16 |
|
|
|
17 |
W=0 |
|
|
18 |
N1=1 |
|
|
19 |
N2=2 |
|
|
20 |
N3=3 |
|
|
21 |
REM=4 |
|
|
22 |
classes= ['W','N1', 'N2','N3','REM'] |
|
|
23 |
n_classes = len(classes) |
|
|
24 |
|
|
|
25 |
def evaluate_metrics(cm): |
|
|
26 |
|
|
|
27 |
print ("Confusion matrix:") |
|
|
28 |
print (cm) |
|
|
29 |
|
|
|
30 |
cm = cm.astype(np.float32) |
|
|
31 |
FP = cm.sum(axis=0) - np.diag(cm) |
|
|
32 |
FN = cm.sum(axis=1) - np.diag(cm) |
|
|
33 |
TP = np.diag(cm) |
|
|
34 |
TN = cm.sum() - (FP + FN + TP) |
|
|
35 |
# https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal |
|
|
36 |
# Sensitivity, hit rate, recall, or true positive rate |
|
|
37 |
TPR = TP / (TP + FN) |
|
|
38 |
# Specificity or true negative rate |
|
|
39 |
TNR = TN / (TN + FP) |
|
|
40 |
# Precision or positive predictive value |
|
|
41 |
PPV = TP / (TP + FP) |
|
|
42 |
# Negative predictive value |
|
|
43 |
NPV = TN / (TN + FN) |
|
|
44 |
# Fall out or false positive rate |
|
|
45 |
FPR = FP / (FP + TN) |
|
|
46 |
# False negative rate |
|
|
47 |
FNR = FN / (TP + FN) |
|
|
48 |
# False discovery rate |
|
|
49 |
FDR = FP / (TP + FP) |
|
|
50 |
|
|
|
51 |
# Overall accuracy |
|
|
52 |
ACC = (TP + TN) / (TP + FP + FN + TN) |
|
|
53 |
# ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN)) |
|
|
54 |
ACC_macro = np.mean(ACC) # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average) |
|
|
55 |
|
|
|
56 |
F1 = (2 * PPV * TPR) / (PPV + TPR) |
|
|
57 |
F1_macro = np.mean(F1) |
|
|
58 |
|
|
|
59 |
print ("Sample: {}".format(int(np.sum(cm)))) |
|
|
60 |
for index_ in range(n_classes): |
|
|
61 |
print ("{}: {}".format(classes[index_], int(TP[index_] + FN[index_]))) |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
return ACC_macro,ACC, F1_macro, F1, TPR, TNR, PPV |
|
|
65 |
|
|
|
66 |
def print_performance(cm,y_true=[],y_pred=[]): |
|
|
67 |
tp = np.diagonal(cm).astype(np.float) |
|
|
68 |
tpfp = np.sum(cm, axis=0).astype(np.float) # sum of each col |
|
|
69 |
tpfn = np.sum(cm, axis=1).astype(np.float) # sum of each row |
|
|
70 |
acc = np.sum(tp)/ np.sum(cm) |
|
|
71 |
precision = tp / tpfp |
|
|
72 |
recall = tp / tpfn |
|
|
73 |
f1 = (2 * precision * recall) / (precision + recall) |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
FP = cm.sum(axis=0).astype(np.float) - np.diag(cm) |
|
|
77 |
FN = cm.sum(axis=1).astype(np.float) - np.diag(cm) |
|
|
78 |
TP = np.diag(cm).astype(np.float) |
|
|
79 |
TN = cm.sum().astype(np.float) - (FP + FN + TP) |
|
|
80 |
specificity = TN / (TN + FP) #TNR |
|
|
81 |
|
|
|
82 |
mf1 = np.mean(f1) |
|
|
83 |
|
|
|
84 |
print ("Sample: {}".format(np.sum(cm))) |
|
|
85 |
print ("W: {}".format(tpfn[W])) |
|
|
86 |
print ("N1: {}".format(tpfn[N1])) |
|
|
87 |
print ("N2: {}".format(tpfn[N2])) |
|
|
88 |
print ("N3: {}".format(tpfn[N3])) |
|
|
89 |
print ("REM: {}".format(tpfn[REM])) |
|
|
90 |
print ("Confusion matrix:") |
|
|
91 |
print (cm) |
|
|
92 |
print ("Precision(PPV): {}".format(precision)) |
|
|
93 |
print ("Recall(Sensitivity): {}".format(recall)) |
|
|
94 |
print ("Specificity: {}".format(specificity)) |
|
|
95 |
print ("F1: {}".format(f1)) |
|
|
96 |
if (len(y_true)>0): |
|
|
97 |
print ("Overall accuracy: {}".format(np.mean(y_true == y_pred))) |
|
|
98 |
print ("Cohen's kappa score: {}".format(cohen_kappa_score(y_true, y_pred))) |
|
|
99 |
|
|
|
100 |
else: |
|
|
101 |
print ("Overall accuracy: {}".format(acc)) |
|
|
102 |
print ("Macro-F1 accuracy: {}".format(mf1)) |
|
|
103 |
|
|
|
104 |
|
|
|
105 |
def perf_overall(data_dir): |
|
|
106 |
# Remove non-output files, and perform ascending sort |
|
|
107 |
allfiles = os.listdir(data_dir) |
|
|
108 |
outputfiles = [] |
|
|
109 |
for idx, f in enumerate(allfiles): |
|
|
110 |
if re.match("^output_.+\d+\.npz", f): |
|
|
111 |
outputfiles.append(os.path.join(data_dir, f)) |
|
|
112 |
outputfiles.sort() |
|
|
113 |
|
|
|
114 |
y_true = [] |
|
|
115 |
y_pred = [] |
|
|
116 |
for fpath in outputfiles: |
|
|
117 |
with np.load(fpath) as f: |
|
|
118 |
print(f["y_true"].shape) |
|
|
119 |
if len(f["y_true"].shape) == 1: |
|
|
120 |
if len(f["y_true"]) < 10: |
|
|
121 |
f_y_true = np.hstack(f["y_true"]) |
|
|
122 |
f_y_pred = np.hstack(f["y_pred"]) |
|
|
123 |
else: |
|
|
124 |
f_y_true = f["y_true"] |
|
|
125 |
f_y_pred = f["y_pred"] |
|
|
126 |
else: |
|
|
127 |
f_y_true = f["y_true"].flatten() |
|
|
128 |
f_y_pred = f["y_pred"].flatten() |
|
|
129 |
|
|
|
130 |
y_true.extend(f_y_true) |
|
|
131 |
y_pred.extend(f_y_pred) |
|
|
132 |
|
|
|
133 |
print ("File: {}".format(fpath)) |
|
|
134 |
cm = confusion_matrix(f_y_true, f_y_pred, labels=[0, 1, 2, 3, 4]) |
|
|
135 |
print_performance(cm) |
|
|
136 |
print (" ") |
|
|
137 |
|
|
|
138 |
y_true = np.asarray(y_true) |
|
|
139 |
y_pred = np.asarray(y_pred) |
|
|
140 |
sio.savemat('con_matrix_sleep.mat',{'y_true': y_true, 'y_pred': y_pred}) |
|
|
141 |
cm = confusion_matrix(y_true, y_pred,labels=range(n_classes)) |
|
|
142 |
acc = np.mean(y_true == y_pred) |
|
|
143 |
mf1 = f1_score(y_true, y_pred, average="macro") |
|
|
144 |
|
|
|
145 |
total = np.sum(cm, axis=1) |
|
|
146 |
|
|
|
147 |
print "Ours:" |
|
|
148 |
print_performance(cm,y_true,y_pred) |
|
|
149 |
|
|
|
150 |
|
|
|
151 |
def main(): |
|
|
152 |
parser = argparse.ArgumentParser() |
|
|
153 |
parser.add_argument("--data_dir", type=str, default="outputs_2013/outputs_eeg_fpz_cz", |
|
|
154 |
help="Directory where to load prediction outputs") |
|
|
155 |
args = parser.parse_args() |
|
|
156 |
|
|
|
157 |
if args.data_dir is not None: |
|
|
158 |
perf_overall(data_dir=args.data_dir) |
|
|
159 |
|
|
|
160 |
|
|
|
161 |
if __name__ == "__main__": |
|
|
162 |
main() |