Diff of /src/utils.py [000000] .. [a378de]

from __future__ import division, print_function
from keras.callbacks import LearningRateScheduler
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, f1_score, classification_report
import os
import h5py


def mkdir_recursive(path):
    if path == "":
        return
    sub_path = os.path.dirname(path)
    if not os.path.exists(sub_path):
        mkdir_recursive(sub_path)
    if not os.path.exists(path):
        print("Creating directory " + path)
        os.mkdir(path)
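
# Usage note: mkdir_recursive('results/plots') creates each missing parent
# directory in turn; on Python 3 the same effect is available via
# os.makedirs(path, exist_ok=True).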


def loaddata(input_size, feature):
    mkdir_recursive('dataset')
    print("Loading training data...")
    with h5py.File('dataset/train.keras', 'r') as f:
        trainData = {key: f[key][...] for key in f.keys()}

    print("Loading training labels...")
    with h5py.File('dataset/trainlabel.keras', 'r') as f:
        trainlabelData = {key: f[key][...] for key in f.keys()}

    print("Available features in training data:", list(trainData.keys()))
    print("Available features in label data:", list(trainlabelData.keys()))

    X = np.float32(trainData[feature])
    y = np.float32(trainlabelData[feature])
    print("Training shapes before shuffle - X:", X.shape, "y:", y.shape)
    print("Any NaN in X:", np.any(np.isnan(X)), "y:", np.any(np.isnan(y)))

    # shuffle samples and labels together, then split them back apart
    att = np.concatenate((X, y), axis=1)
    np.random.shuffle(att)
    X, y = att[:, :input_size], att[:, input_size:]
    print("Training shapes after shuffle - X:", X.shape, "y:", y.shape)
    print("Any NaN after shuffle - X:", np.any(np.isnan(X)), "y:", np.any(np.isnan(y)))

    print("Loading validation data...")
    with h5py.File('dataset/test.keras', 'r') as f:
        valData = {key: f[key][...] for key in f.keys()}

    print("Loading validation labels...")
    with h5py.File('dataset/testlabel.keras', 'r') as f:
        vallabelData = {key: f[key][...] for key in f.keys()}

    Xval = np.float32(valData[feature])
    yval = np.float32(vallabelData[feature])
    print("Validation shapes - Xval:", Xval.shape, "yval:", yval.shape)
    print("Any NaN in validation - Xval:", np.any(np.isnan(Xval)), "yval:", np.any(np.isnan(yval)))

    return (X, y, Xval, yval)
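
# Example usage (illustrative sketch -- the input_size and the feature name
# below are assumptions, not values defined in this file):
#
#   X, y, Xval, yval = loaddata(input_size=256, feature="MLII")
#   print(X.shape, y.shape, Xval.shape, yval.shape)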


class LearningRateSchedulerPerBatch(LearningRateScheduler):
    """Code from https://towardsdatascience.com/resuming-a-training-process-with-keras-3e93152ee11a
    Callback that modifies the default learning rate scheduler to operate on each batch."""
    def __init__(self, schedule, verbose=0):
        super(LearningRateSchedulerPerBatch, self).__init__(schedule, verbose)
        self.count = 0  # global batch index (the regular `batch` argument is the index within the epoch)

    def on_epoch_begin(self, epoch, logs=None):
        pass  # disable the per-epoch behaviour of the parent class

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        super(LearningRateSchedulerPerBatch, self).on_epoch_begin(self.count, logs)

    def on_batch_end(self, batch, logs=None):
        super(LearningRateSchedulerPerBatch, self).on_epoch_end(self.count, logs)
        self.count += 1
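
# Example usage (illustrative sketch -- the schedule function and the fit
# arguments are assumptions, not defined in this file):
#
#   def schedule(batch, lr=0.001):
#       return lr * (0.1 ** (batch // 10000))  # step decay per global batch
#
#   lr_callback = LearningRateSchedulerPerBatch(schedule, verbose=1)
#   model.fit(X, y, epochs=10, callbacks=[lr_callback])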


def plot_confusion_matrix(y_true, y_pred, classes, feature,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Modified from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html"""
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    cm = confusion_matrix(y_true, y_pred)
    #classes = classes[unique_labels(y_true, y_pred)]

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # annotate each cell with its count (or rate), in a colour that contrasts with the background
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    mkdir_recursive('results')
    fig.savefig('results/confusionMatrix-' + feature + '.eps', format='eps', dpi=1000)
    return ax
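
# Example usage (illustrative sketch -- the labels and class names are made up):
#
#   y_true = np.array([0, 1, 2, 2, 1])
#   y_pred = np.array([0, 1, 2, 1, 1])
#   plot_confusion_matrix(y_true, y_pred, classes=np.array(['N', 'V', 'F']),
#                         feature="MLII", normalize=True)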


# Precision-Recall curves and ROC curves for each class
def PR_ROC_curves(ytrue, ypred, classes, ypred_mat):
    ybool = ypred == ytrue
    f, ax = plt.subplots(3, 4, figsize=(10, 10))
    ax = [a for i in ax for a in i]  # flatten the 3x4 grid of axes

    e = -1
    for c in classes:
        # indices of samples whose true or predicted label is class c
        idx1 = [n for n, x in enumerate(ytrue) if classes[x] == c]
        idx2 = [n for n, x in enumerate(ypred) if classes[x] == c]
        idx = idx1 + idx2
        if idx == []:
            continue
        bi_ytrue = ytrue[idx]
        bi_prob = ypred_mat[idx, :]
        bi_ybool = np.array(ybool[idx])
        bi_yscore = np.array([bi_prob[x][bi_ytrue[x]] for x in range(len(idx))])
        try:
            print("AUC for {}: {}".format(c, roc_auc_score(bi_ybool + 0, bi_yscore)))
            e += 1
        except ValueError:
            continue
        ppvs, senss, thresholds = precision_recall_curve(bi_ybool, bi_yscore)
        cax = ax[2 * e]
        # recall on the x-axis, precision on the y-axis, to match the labels below
        cax.plot(senss, ppvs, lw=2, label="Model")
        cax.set_xlim(-0.008, 1.05)
        cax.set_ylim(0.0, 1.05)
        cax.set_title("Class {}".format(c))
        cax.set_xlabel('Sensitivity (Recall)')
        cax.set_ylabel('PPV (Precision)')
        cax.legend(loc=3)

        fpr, tpr, thresholds = roc_curve(bi_ybool, bi_yscore)
        cax2 = ax[2 * e + 1]
        cax2.plot(fpr, tpr, lw=2, label="Model")
        cax2.set_xlim(-0.1, 1.)
        cax2.set_ylim(0.0, 1.05)
        cax2.set_title("Class {}".format(c))
        cax2.set_xlabel('1 - Specificity')
        cax2.set_ylabel('Sensitivity')
        cax2.legend(loc=4)

    mkdir_recursive("results")
    plt.savefig("results/model_prec_recall_and_roc.eps",
                dpi=400,
                format='eps',
                bbox_inches='tight')
    plt.close()
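
# Example usage (illustrative sketch -- assumes `model`, `Xval`, and `yval`
# from loaddata above; the class names are made up):
#
#   ypred_mat = model.predict(Xval)
#   ytrue = np.argmax(yval, axis=1)
#   ypred = np.argmax(ypred_mat, axis=1)
#   PR_ROC_curves(ytrue, ypred, np.array(['N', 'V', 'F']), ypred_mat)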


def print_results(config, model, Xval, yval, classes):
    if config.trained_model:
        model.load_weights(config.trained_model)
    else:
        model.load_weights('models/{}-latest.keras'.format(config.feature))
    # to combine different trained models; still being tested
    if config.ensemble:
        from keras.models import clone_model
        model2 = clone_model(model)  # a separate instance; `model2 = model` would alias the same weights
        model2.load_weights('models/weights-V1.keras')
        ypred_mat = (model.predict(Xval) + model2.predict(Xval)) / 2
    else:
        ypred_mat = model.predict(Xval)

    print("yval.shape:", yval.shape)

    ytrue = np.argmax(yval, axis=1)
    yscore = np.array([ypred_mat[x][ytrue[x]] for x in range(len(yval))])
    ypred = np.argmax(ypred_mat, axis=1)
    print(classification_report(ytrue, ypred))
    plot_confusion_matrix(ytrue, ypred, classes, feature=config.feature, normalize=False)
    print("F1 score:", f1_score(ytrue, ypred, average=None))
    PR_ROC_curves(ytrue, ypred, classes, ypred_mat)
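
# Example usage (illustrative sketch -- the config fields mirror the ones this
# function reads; their values here are assumptions):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(trained_model=None, feature="MLII", ensemble=False)
#   print_results(config, model, Xval, yval, classes=np.array(['N', 'V', 'F']))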


def add_noise(config):
    import csv
    from scipy.io import loadmat
    from scipy.signal import find_peaks, resample
    from sklearn import preprocessing

    noises = dict()
    noises["trainset"] = list()
    noises["testset"] = list()
    try:
        testlabel = list(csv.reader(open('training2017/REFERENCE.csv')))
    except IOError:
        # download the CinC 2017 training set if it is not present locally
        cmd = "curl -O https://archive.physionet.org/challenge/2017/training2017.zip"
        os.system(cmd)
        os.system("unzip training2017.zip")
        testlabel = list(csv.reader(open('training2017/REFERENCE.csv')))
    for i, label in enumerate(testlabel):
        if label[1] == '~':  # '~' marks noisy records in REFERENCE.csv
            filename = 'training2017/' + label[0] + '.mat'
            noise = loadmat(filename)['val']
            _, size = noise.shape
            noise = noise.reshape(size,)
            noise = np.nan_to_num(noise)  # remove NaNs and Infs
            # resample to match the data sampling rates: 360 Hz (MIT-BIH) vs 300 Hz (CinC)
            noise = resample(noise, int(len(noise) * 360 / 300))
            noise = preprocessing.scale(noise)
            noise = noise / 1000 * 6  # rough normalization, to be improved
            peaks, _ = find_peaks(noise, distance=150)
            choices = 10  # 256*10 from 9000
            picked_peaks = np.random.choice(peaks, choices, replace=False)
            for j, peak in enumerate(picked_peaks):
                if config.input_size // 2 < peak < len(noise) - config.input_size // 2:
                    start, end = peak - config.input_size // 2, peak + config.input_size // 2
                    if i > len(testlabel) / 6:
                        noises["trainset"].append(noise[start:end].tolist())
                    else:
                        noises["testset"].append(noise[start:end].tolist())
    return noises
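
# Example usage (illustrative sketch -- the input_size value is an assumption):
#
#   from types import SimpleNamespace
#   noises = add_noise(SimpleNamespace(input_size=256))
#   print(len(noises["trainset"]), len(noises["testset"]))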


def preprocess(data, config):
    from scipy.signal import find_peaks, resample
    from sklearn import preprocessing

    sr = config.sample_rate
    if sr is None:
        sr = 300
    data = np.nan_to_num(data)  # remove NaNs and Infs
    # resample to match the data sampling rates: 360 Hz (MIT-BIH) vs 300 Hz (CinC)
    data = resample(data, int(len(data) * 360 / sr))
    data = preprocessing.scale(data)
    peaks, _ = find_peaks(data, distance=150)
    data = data.reshape(1, len(data))
    data = np.expand_dims(data, axis=2)  # shape (batch, length, channels), as required by Keras
    return data, peaks
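
# Example usage (illustrative sketch -- the raw signal and the sample rate are
# made up; a real record would come from a .mat or .csv file):
#
#   from types import SimpleNamespace
#   raw = np.random.randn(9000)  # stand-in for one ECG record
#   data, peaks = preprocess(raw, SimpleNamespace(sample_rate=300))
#   print(data.shape, len(peaks))  # data has shape (1, resampled_length, 1)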


# predict
def uploadedData(filename, csvbool=True):
    csvlist = list()  # defined up front so the function cannot return an unbound name
    if csvbool:
        with open(filename, 'r') as csvfile:
            for e in csvfile:
                if len(e.split()) == 1:
                    csvlist.append(float(e))
                else:
                    csvlist.append(e)
    return csvlist
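
# Example end-to-end prediction sketch (the file name and `model` are
# assumptions; uploadedData reads the signal, preprocess shapes it for Keras):
#
#   from types import SimpleNamespace
#   signal = np.array(uploadedData('upload.csv'))
#   data, peaks = preprocess(signal, SimpleNamespace(sample_rate=300))
#   probs = model.predict(data)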