|
a |
|
b/src/predict.py |
|
|
1 |
""" |
|
|
2 |
The CINC data is provided by https://physionet.org/challenge/2017/ |
|
|
3 |
""" |
|
|
4 |
from __future__ import division, print_function |
|
|
5 |
import numpy as np |
|
|
6 |
from config import get_config |
|
|
7 |
from utils import * |
|
|
8 |
from graph import * |
|
|
9 |
import os |
|
|
10 |
|
|
|
11 |
def cincData(config): |
|
|
12 |
if config.cinc_download: |
|
|
13 |
cmd = "curl -O https://archive.physionet.org/challenge/2017/training2017.zip" |
|
|
14 |
os.system(cmd) |
|
|
15 |
os.system("unzip training2017.zip") |
|
|
16 |
num = config.num |
|
|
17 |
import csv |
|
|
18 |
testlabel = [] |
|
|
19 |
|
|
|
20 |
with open('training2017/REFERENCE.csv') as csv_file: |
|
|
21 |
csv_reader = csv.reader(csv_file, delimiter=',') |
|
|
22 |
line_count = 0 |
|
|
23 |
for row in csv_reader: |
|
|
24 |
testlabel.append([row[0],row[1]]) |
|
|
25 |
#print(row[0], row[1]) |
|
|
26 |
line_count += 1 |
|
|
27 |
print(f'Processed {line_count} lines.') |
|
|
28 |
if num == None: |
|
|
29 |
high = len(testlabel)-1 |
|
|
30 |
num = np.random.randint(1,high) |
|
|
31 |
filename , label = testlabel[num-1] |
|
|
32 |
filename = 'training2017/'+ filename + '.mat' |
|
|
33 |
from scipy.io import loadmat |
|
|
34 |
data = loadmat(filename) |
|
|
35 |
print("The record of "+ filename) |
|
|
36 |
if not config.upload: |
|
|
37 |
data = data['val'] |
|
|
38 |
_, size = data.shape |
|
|
39 |
data = data.reshape(size,) |
|
|
40 |
else: |
|
|
41 |
data = np.array(data) |
|
|
42 |
return data, label |
|
|
43 |
|
|
|
44 |
def predict(data, label, peaks, config): |
|
|
45 |
classesM = ['N','Ventricular','Paced','A','F','Noise'] |
|
|
46 |
predicted, result = predictByPart(data, peaks) |
|
|
47 |
print("The predicted", predicted) |
|
|
48 |
sumPredict = sum(predicted[x][1] for x in range(len(predicted))) |
|
|
49 |
avgPredict = sumPredict/len(predicted) |
|
|
50 |
print("The average of the predict is:", avgPredict) |
|
|
51 |
print("The most predicted label is {} with {:3.1f}% certainty".format(classesM[avgPredict.argmax()], 100*max(avgPredict))) |
|
|
52 |
print("avgPredict", avgPredict) |
|
|
53 |
sec_idx = avgPredict.argsort()[-2] |
|
|
54 |
print("The second predicted label is {} with {:3.1f}% certainty".format(classesM[sec_idx], 100*avgPredict[sec_idx])) |
|
|
55 |
print("The original label of the record is " + label) |
|
|
56 |
if config.upload: |
|
|
57 |
return predicted, classesM[avgPredict.argmax()], 100*max(avgPredict) |
|
|
58 |
|
|
|
59 |
def predictByPart(data, peaks): |
|
|
60 |
classesM = ['N','Ventricular','Paced','A','F','Noise']#,'L','R','f','j','E','a','J','Q','e','S'] |
|
|
61 |
predicted = list() |
|
|
62 |
result = "" |
|
|
63 |
counter = [0]* len(classesM) |
|
|
64 |
from keras.models import load_model |
|
|
65 |
model = load_model( |
|
|
66 |
'models/MLII-latest.keras', |
|
|
67 |
custom_objects={ |
|
|
68 |
'zeropad': zeropad, |
|
|
69 |
'zeropad_output_shape': zeropad_output_shape |
|
|
70 |
} |
|
|
71 |
) |
|
|
72 |
config = get_config() |
|
|
73 |
for i, peak in enumerate(peaks[3:-1]): |
|
|
74 |
total_n =len(peaks) |
|
|
75 |
start, end = peak-config.input_size//2 , peak+config.input_size//2 |
|
|
76 |
prob = model.predict(data[:, start:end]) |
|
|
77 |
prob = prob[0] |
|
|
78 |
ann = np.argmax(prob) |
|
|
79 |
counter[ann]+=1 |
|
|
80 |
|
|
|
81 |
if classesM[ann] != "N": |
|
|
82 |
print("The {}/{}-record classified as {} with {:3.1f}% certainty".format(i,total_n,classesM[ann],100*prob[ann])) |
|
|
83 |
result += "("+ classesM[ann] +":" + str(round(100*prob[ann],1)) + "%)" |
|
|
84 |
predicted.append([classesM[ann],prob]) |
|
|
85 |
if classesM[ann] != 'N' and prob[ann] > 0.95: |
|
|
86 |
import matplotlib.pyplot as plt |
|
|
87 |
plt.plot(data[:, start:end][0,:,0],) |
|
|
88 |
mkdir_recursive('results') |
|
|
89 |
plt.savefig('results/hazard-'+classesM[ann]+'.png', format="png", dpi = 300) |
|
|
90 |
plt.close() |
|
|
91 |
result += "{}-N, {}-Venticular, {}-Paced, {}-A, {}-F, {}-Noise".format(counter[0], counter[1], counter[2], counter[3], counter[4], counter[5]) |
|
|
92 |
return predicted, result |
|
|
93 |
|
|
|
94 |
def main(config): |
|
|
95 |
classesM= ['N','Ventricular','Paced','A','F', 'Noise']#,'L','R','f','j','E','a','J','Q','e','S'] |
|
|
96 |
|
|
|
97 |
if config.upload: |
|
|
98 |
data = uploadedData(file) |
|
|
99 |
else: |
|
|
100 |
data, label = cincData(config) |
|
|
101 |
data, peaks = preprocess(data, config) |
|
|
102 |
return predict(data, label, peaks, config) |
|
|
103 |
|
|
|
104 |
if __name__=='__main__': |
|
|
105 |
config = get_config() |
|
|
106 |
main(config) |