Diff of /src/predict.py [000000] .. [a378de]

Switch to side-by-side view

--- a
+++ b/src/predict.py
@@ -0,0 +1,106 @@
+"""
+The CINC data is provided by https://physionet.org/challenge/2017/ 
+"""
+from __future__ import division, print_function
+import numpy as np
+from config import get_config
+from utils import *
+from graph import *
+import os 
+
+def cincData(config):
+    if config.cinc_download:
+      cmd = "curl -O https://archive.physionet.org/challenge/2017/training2017.zip"
+      os.system(cmd)
+      os.system("unzip training2017.zip")
+    num = config.num
+    import csv
+    testlabel = []
+
+    with open('training2017/REFERENCE.csv') as csv_file:
+      csv_reader = csv.reader(csv_file, delimiter=',')
+      line_count = 0
+      for row in csv_reader:
+        testlabel.append([row[0],row[1]])
+        #print(row[0], row[1])
+        line_count += 1
+      print(f'Processed {line_count} lines.')
+    if num == None:
+      high = len(testlabel)-1
+      num = np.random.randint(1,high)
+    filename , label = testlabel[num-1]
+    filename = 'training2017/'+ filename + '.mat'
+    from scipy.io import loadmat
+    data = loadmat(filename)
+    print("The record of "+ filename)
+    if not config.upload:
+        data = data['val']
+        _, size = data.shape
+        data = data.reshape(size,)
+    else:
+        data = np.array(data)
+    return data, label
+
+def predict(data, label, peaks, config):
+    classesM = ['N','Ventricular','Paced','A','F','Noise']
+    predicted, result  = predictByPart(data, peaks)
+    print("The predicted", predicted)
+    sumPredict = sum(predicted[x][1] for x in range(len(predicted)))
+    avgPredict = sumPredict/len(predicted)
+    print("The average of the predict is:", avgPredict)
+    print("The most predicted label is {} with {:3.1f}% certainty".format(classesM[avgPredict.argmax()], 100*max(avgPredict)))
+    print("avgPredict", avgPredict)
+    sec_idx = avgPredict.argsort()[-2]
+    print("The second predicted label is {} with {:3.1f}% certainty".format(classesM[sec_idx], 100*avgPredict[sec_idx]))
+    print("The original label of the record is " + label)
+    if config.upload:
+      return predicted, classesM[avgPredict.argmax()], 100*max(avgPredict)
+
+def predictByPart(data, peaks):
+    classesM = ['N','Ventricular','Paced','A','F','Noise']#,'L','R','f','j','E','a','J','Q','e','S']
+    predicted = list()
+    result = ""
+    counter = [0]* len(classesM)
+    from keras.models import load_model
+    model = load_model(
+    'models/MLII-latest.keras', 
+    custom_objects={
+        'zeropad': zeropad,
+        'zeropad_output_shape': zeropad_output_shape
+      }
+    )
+    config = get_config() 
+    for i, peak in enumerate(peaks[3:-1]):
+      total_n =len(peaks)
+      start, end =  peak-config.input_size//2 , peak+config.input_size//2
+      prob = model.predict(data[:, start:end])
+      prob = prob[0]
+      ann = np.argmax(prob)
+      counter[ann]+=1
+      
+      if classesM[ann] != "N":
+        print("The {}/{}-record classified as {} with {:3.1f}% certainty".format(i,total_n,classesM[ann],100*prob[ann]))
+      result += "("+ classesM[ann] +":" + str(round(100*prob[ann],1)) + "%)"
+      predicted.append([classesM[ann],prob])
+      if classesM[ann] != 'N' and prob[ann] > 0.95:
+        import matplotlib.pyplot as plt
+        plt.plot(data[:, start:end][0,:,0],)
+        mkdir_recursive('results')
+        plt.savefig('results/hazard-'+classesM[ann]+'.png', format="png", dpi = 300)
+        plt.close()
+    result += "{}-N, {}-Venticular, {}-Paced, {}-A, {}-F, {}-Noise".format(counter[0], counter[1], counter[2], counter[3], counter[4], counter[5])
+    return predicted, result
+
+def main(config):
+  classesM= ['N','Ventricular','Paced','A','F', 'Noise']#,'L','R','f','j','E','a','J','Q','e','S']
+
+  if config.upload:
+    data = uploadedData(file)
+  else:
+    data, label = cincData(config)
+  data, peaks = preprocess(data, config)
+  return predict(data, label, peaks, config)
+
+if __name__=='__main__':
+  config = get_config()
+  main(config)