Diff of /src/predict.py [000000] .. [a378de]

Switch to unified view

a b/src/predict.py
1
"""
2
The CINC data is provided by https://physionet.org/challenge/2017/ 
3
"""
4
from __future__ import division, print_function
5
import numpy as np
6
from config import get_config
7
from utils import *
8
from graph import *
9
import os 
10
11
def cincData(config):
    """Load one record and its reference label from the CinC 2017 training set.

    Args:
        config: object exposing three attributes:
            ``cinc_download`` -- when truthy, ``training2017.zip`` is
            downloaded with ``curl`` and unpacked with ``unzip`` into the
            current working directory before anything is read;
            ``num`` -- 1-based row of ``training2017/REFERENCE.csv`` to
            load, or ``None`` to pick a row uniformly at random;
            ``upload`` -- selects how the loaded ``.mat`` content is
            post-processed (see Returns).

    Returns:
        tuple: ``(data, label)``. ``label`` is the second CSV column of the
        selected row. When ``config.upload`` is falsy, ``data`` is the
        ``'val'`` entry of the ``.mat`` file reshaped to one dimension;
        otherwise it is ``np.array`` applied to the whole ``loadmat`` result.

    Raises:
        ValueError: if the reference CSV contains no rows.
        IndexError: if ``config.num`` is outside ``1..number_of_rows``.
    """
    import csv
    from scipy.io import loadmat

    if config.cinc_download:
        # NOTE: exit codes are not checked; a failed download surfaces later
        # as a missing REFERENCE.csv.
        os.system("curl -O https://archive.physionet.org/challenge/2017/training2017.zip")
        os.system("unzip training2017.zip")

    # newline='' is what the csv module asks for when handed a file object.
    with open('training2017/REFERENCE.csv', newline='') as csv_file:
        testlabel = [[row[0], row[1]]
                     for row in csv.reader(csv_file, delimiter=',')]
    print(f'Processed {len(testlabel)} lines.')

    if not testlabel:
        raise ValueError("training2017/REFERENCE.csv contains no records")

    num = config.num
    if num is None:
        # randint's upper bound is exclusive, so +1 makes every record
        # (including the last one) selectable.
        num = np.random.randint(1, len(testlabel) + 1)
    elif not 1 <= num <= len(testlabel):
        # Without this check 0 / negative values would silently wrap around
        # to a record at the end of the list.
        raise IndexError(
            "record number {} is out of range 1..{}".format(num, len(testlabel)))

    filename, label = testlabel[num - 1]
    filename = 'training2017/' + filename + '.mat'
    data = loadmat(filename)
    print("The record of " + filename)

    if not config.upload:
        data = data['val']
        # The reshape only succeeds when the leading dimension is 1, i.e. a
        # single-row signal is flattened to shape (size,).
        _, size = data.shape
        data = data.reshape(size,)
    else:
        # NOTE(review): this wraps the whole loadmat result (a dict) in an
        # array rather than extracting 'val' -- confirm this is intended.
        data = np.array(data)
    return data, label
43
44
def predict(data, label, peaks, config):
    """Classify a record beat by beat and report the averaged prediction.

    Every window is classified by ``predictByPart``; the per-beat probability
    vectors are averaged and the two most likely classes are printed together
    with the reference ``label``.

    Args:
        data: signal forwarded unchanged to ``predictByPart``.
        label: reference label of the record; concatenated into a printed
            message, so it must be a ``str``.
        peaks: peak positions forwarded unchanged to ``predictByPart``.
        config: object whose ``upload`` attribute decides whether a result
            is returned.

    Returns:
        ``(predicted, best_class, certainty_percent)`` when ``config.upload``
        is truthy, otherwise ``None``. ``predicted`` is the list of
        ``[class_name, probabilities]`` pairs produced by ``predictByPart``.

    Raises:
        ValueError: if no beat was classified, so there is nothing to
            average.
    """
    classesM = ['N', 'Ventricular', 'Paced', 'A', 'F', 'Noise']
    predicted, _ = predictByPart(data, peaks)
    print("The predicted", predicted)

    if not predicted:
        # Averaging zero predictions would otherwise fail with an opaque
        # ZeroDivisionError.
        raise ValueError("no beats were classified; cannot average predictions")

    avgPredict = sum(prob for _, prob in predicted) / len(predicted)
    print("The average of the predict is:", avgPredict)

    best_idx = avgPredict.argmax()
    best_pct = 100 * avgPredict[best_idx]
    print("The most predicted label is {} with {:3.1f}% certainty".format(classesM[best_idx], best_pct))
    print("avgPredict", avgPredict)

    # argsort is ascending, so the second-to-last index is the runner-up.
    sec_idx = avgPredict.argsort()[-2]
    print("The second predicted label is {} with {:3.1f}% certainty".format(classesM[sec_idx], 100 * avgPredict[sec_idx]))
    print("The original label of the record is " + label)

    if config.upload:
        return predicted, classesM[best_idx], best_pct
    return None
58
59
def predictByPart(data, peaks):
    """Classify the window around each peak with the saved Keras model.

    The model is loaded from ``models/MLII-latest.keras`` on every call. The
    first three peaks and the last one are not classified. For each remaining
    peak a window of ``config.input_size`` samples centred on the peak is cut
    along the second axis of ``data`` and passed to the model. Windows
    classified as something other than ``'N'`` with probability above 0.95
    are plotted to ``results/hazard-<class>.png`` (a later hit of the same
    class overwrites the earlier file).

    Args:
        data: array indexed as ``data[:, start:end]``; the plotting code
            additionally indexes the window as ``[0, :, 0]``.
        peaks: sequence of integer sample positions.

    Returns:
        tuple: ``(predicted, result)`` where ``predicted`` is a list of
        ``[class_name, probabilities]`` per classified beat and ``result`` is
        a string of ``(class:percent%)`` entries followed by per-class
        counts.
    """
    classesM = ['N', 'Ventricular', 'Paced', 'A', 'F', 'Noise']  # ,'L','R','f','j','E','a','J','Q','e','S']
    from keras.models import load_model
    model = load_model(
        'models/MLII-latest.keras',
        custom_objects={
            'zeropad': zeropad,
            'zeropad_output_shape': zeropad_output_shape,
        },
    )
    config = get_config()

    # Loop invariants, computed once instead of on every iteration.
    half_window = config.input_size // 2
    total_n = len(peaks)

    predicted = []
    beat_summaries = []
    counter = [0] * len(classesM)

    for i, peak in enumerate(peaks[3:-1]):
        start, end = peak - half_window, peak + half_window
        if start < 0:
            # A negative start would be interpreted as an index from the end
            # of the signal and yield an empty or unrelated window.
            print("Skipping the {}/{}-record: window starts before the signal".format(i, total_n))
            continue

        window = data[:, start:end]
        prob = model.predict(window)[0]
        ann = np.argmax(prob)
        counter[ann] += 1
        beat_class = classesM[ann]

        if beat_class != "N":
            print("The {}/{}-record classified as {} with {:3.1f}% certainty".format(i, total_n, beat_class, 100 * prob[ann]))
        beat_summaries.append("(" + beat_class + ":" + str(round(100 * prob[ann], 1)) + "%)")
        predicted.append([beat_class, prob])

        if beat_class != 'N' and prob[ann] > 0.95:
            # Imported lazily so matplotlib is only needed when there is
            # something to plot.
            import matplotlib.pyplot as plt
            plt.plot(window[0, :, 0],)
            mkdir_recursive('results')
            plt.savefig('results/hazard-' + beat_class + '.png', format="png", dpi=300)
            plt.close()

    # The summary text (including its "Venticular" spelling) is kept exactly
    # as before in case callers display or parse it.
    result = "".join(beat_summaries)
    result += "{}-N, {}-Venticular, {}-Paced, {}-A, {}-F, {}-Noise".format(*counter)
    return predicted, result
93
94
def main(config, file=None):
    """Load a record, preprocess it and run the prediction pipeline.

    Args:
        config: configuration object; ``config.upload`` selects between an
            uploaded file and a record from the CinC training set.
        file: source handed to ``uploadedData`` when ``config.upload`` is
            truthy. Previously this branch referenced an undefined name and
            could not run; the argument is optional so existing
            ``main(config)`` calls keep working.
            NOTE(review): the accepted type (path or file object) depends on
            ``uploadedData``, which is not defined in this file -- confirm.

    Returns:
        Whatever ``predict`` returns for the loaded record.

    Raises:
        ValueError: if ``config.upload`` is truthy but no ``file`` is given.
    """
    if config.upload:
        if file is None:
            raise ValueError("config.upload is set but no file was passed to main()")
        data = uploadedData(file)
        # An uploaded record has no reference annotation; predict() still
        # needs a string to print.
        label = "unknown"
    else:
        data, label = cincData(config)
    data, peaks = preprocess(data, config)
    return predict(data, label, peaks, config)
103
104
if __name__ == "__main__":
    # Script entry point: build the configuration via get_config() and run
    # the pipeline once.
    config = get_config()
    main(config)