"""
Author: Mondejar Guerra
VARPA
University of A Coruna
April 2017

Description: Train and evaluate mitdb with interpatient split (train/test).
Uses a DNN classifier.
"""

# Imports deduplicated (numpy and matplotlib were imported twice) and
# grouped stdlib / third-party; a stray diff-header line was removed.
import collections
import csv
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)
def load_data(output_path, window_size, compute_RR_interval_feature, compute_wavelets):
  """Load the precomputed train/eval splits from CSV files.

  The file name suffix encodes the preprocessing options, e.g.
  '_160_wv_RR.csv' for window_size=160 with wavelet and RR features.

  Args:
    output_path: directory prefix where the CSV files live (string,
      concatenated directly — include a trailing separator).
    window_size: beat window size used when the files were generated.
    compute_RR_interval_feature: whether the '_RR' suffix is present.
    compute_wavelets: whether the '_wv' suffix is present.

  Returns:
    Tuple (train_data, train_labels, eval_data, eval_labels); data as
    float arrays, labels as int32 arrays.
  """
  suffix_parts = ['_' + str(window_size)]
  if compute_wavelets:
    suffix_parts.append('_wv')
  if compute_RR_interval_feature:
    suffix_parts.append('_RR')
  extension = ''.join(suffix_parts) + '.csv'

  # Load training and eval data
  train_data = np.loadtxt(output_path + 'train_data' + extension, delimiter=",", dtype=float)
  train_labels = np.loadtxt(output_path + 'train_label' + extension, delimiter=",", dtype=np.int32)
  eval_data = np.loadtxt(output_path + 'eval_data' + extension, delimiter=",", dtype=float)
  eval_labels = np.loadtxt(output_path + 'eval_label' + extension, delimiter=",", dtype=np.int32)

  return (train_data, train_labels, eval_data, eval_labels)
37
38
def main():
  """Train and evaluate a DNN classifier on the MIT-BIH dataset.

  Loads the precomputed inter-patient train/eval splits, min-max
  normalizes the features, fits a 3-layer DNN (tf.contrib.learn) and
  prints the test accuracy and the confusion matrix.

  Side effects: reads CSVs from a hard-coded dataset path, writes
  model checkpoints to /tmp/mitdb, prints to stdout.
  """
  window_size = 160
  compute_RR_interval_feature = True
  compute_wavelets = True
  dataset = '/home/mondejar/dataset/ECG/mitdb/'
  output_path = dataset + 'm_learning/'

  # 0 Load data
  train_data, train_labels, eval_data, eval_labels = load_data(
      output_path, window_size, compute_RR_interval_feature, compute_wavelets)

  # 1 Preprocess: min-max normalize the wave features and, separately,
  # the RR-interval features, using the global min/max over BOTH splits.
  normalize = True
  if normalize:
    feature_size = len(train_data[0])
    if compute_RR_interval_feature:
      # The last 4 columns are the pre, post, local and global RR intervals.
      feature_size = feature_size - 4

    stacked_wave = np.vstack((train_data[:, 0:feature_size],
                              eval_data[:, 0:feature_size]))
    max_wav = np.amax(stacked_wave)
    min_wav = np.amin(stacked_wave)
    # NOTE(review): divides by zero if every wave value is identical —
    # assumed impossible for real ECG data.
    train_data[:, 0:feature_size] = ((train_data[:, 0:feature_size] - min_wav) / (max_wav - min_wav))
    eval_data[:, 0:feature_size] = ((eval_data[:, 0:feature_size] - min_wav) / (max_wav - min_wav))

    # Normalize the RR-interval features with their own min/max.
    if compute_RR_interval_feature:
      stacked_rr = np.vstack((train_data[:, feature_size:],
                              eval_data[:, feature_size:]))
      max_rr = np.amax(stacked_rr)
      min_rr = np.amin(stacked_rr)
      train_data[:, feature_size:] = ((train_data[:, feature_size:] - min_rr) / (max_rr - min_rr))
      eval_data[:, feature_size:] = ((eval_data[:, feature_size:] - min_rr) / (max_rr - min_rr))

  # 2 Create model: all features are real-valued.
  feature_columns = [tf.contrib.layers.real_valued_column("", dimension=len(train_data[0]))]

  # 3-layer DNN with 10, 20, 10 units respectively, 5 output classes.
  mitdb_classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                    hidden_units=[10, 20, 10],
                                                    n_classes=5,
                                                    model_dir="/tmp/mitdb")

  # Fit model.
  def get_train_inputs():
    # Training inputs as constant tensors.
    x = tf.constant(train_data)
    y = tf.constant(train_labels)
    return x, y

  mitdb_classifier.fit(input_fn=get_train_inputs, steps=2000)

  # Evaluate accuracy.
  def get_test_inputs():
    # Evaluation inputs as constant tensors.
    x = tf.constant(eval_data)
    y = tf.constant(eval_labels)
    return x, y

  accuracy_score = mitdb_classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"]
  print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

  def get_eval_data():
    return np.array(eval_data, dtype=np.float32)

  predictions = list(mitdb_classifier.predict(input_fn=get_eval_data))

  # Confusion matrix: rows = predicted class, columns = true class.
  confusion_matrix = np.zeros((5, 5), dtype='int')
  for idx, predicted in enumerate(predictions):
    confusion_matrix[predicted][eval_labels[idx]] += 1

  # Fix: was a Python 2 print STATEMENT (`print confusion_matrix`), a
  # SyntaxError under Python 3 and inconsistent with the print() calls above.
  print(confusion_matrix)
119
# Script entry point: run training/evaluation when executed directly.
if __name__ == "__main__":
  main()