seq_seq_annot_aami.py
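# Sequence-to-sequence annotation of MIT-BIH heartbeats grouped into the AAMI
# classes (N, S, V, F, Q). Beat segments and labels are read from a
# preprocessed .mat file (key 's2s_mitbih'), the training split is balanced
# with SMOTE, and a CNN + LSTM encoder-decoder is trained to label each
# sequence of max_time consecutive beats.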
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as spio
from sklearn.preprocessing import MinMaxScaler
import random
import time
import os
from datetime import datetime
from sklearn.metrics import confusion_matrix
import tensorflow as tf
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
import argparse

random.seed(654)

def read_mitbih(filename, max_time=100, classes=['F', 'N', 'S', 'V', 'Q'], max_nlabel=100):
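    """Load beat segments and labels from '<filename>.mat' (key 's2s_mitbih').

    Keeps at most max_nlabel beats per class, then groups the beats into
    shuffled sequences of max_time consecutive beats.

    Returns:
        data:   array of shape [n_seqs, max_time, beat_length]
        labels: array of shape [n_seqs, max_time] with one symbol per beat
    """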
    def normalize(data):
        data = np.nan_to_num(data)  # removing NaNs and Infs
        data = data - np.mean(data)
        data = data / np.std(data)
        return data

    # read data
    data = []
    samples = spio.loadmat(filename + ".mat")
    samples = samples['s2s_mitbih']
    values = samples[0]['seg_values']
    labels = samples[0]['seg_labels']
    num_annots = sum([item.shape[0] for item in values])

    n_seqs = num_annots // max_time  # number of complete max_time-beat sequences
    # collect all segments (beats), up to n_seqs * max_time of them
    l_data = 0
    for i, item in enumerate(values):
        l = item.shape[0]
        for itm in item:
            if l_data == n_seqs * max_time:
                break
            data.append(itm[0])
            l_data = l_data + 1

    # collect the corresponding labels
    l_labels = 0
    t_labels = []
    for i, item in enumerate(labels):
        if len(t_labels) == n_seqs * max_time:
            break
        item = item[0]
        for label in item:
            if l_labels == n_seqs * max_time:
                break
            t_labels.append(str(label))
            l_labels = l_labels + 1

    del values
    data = np.asarray(data)
    shape_v = data.shape
    data = np.reshape(data, [shape_v[0], -1])
    t_labels = np.array(t_labels)
    _data = np.asarray([], dtype=np.float64).reshape(0, shape_v[1])
    _labels = np.asarray([], dtype=np.dtype('|S1')).reshape(0,)
    # keep at most max_nlabel randomly selected beats per class
    for cl in classes:
        _label = np.where(t_labels == cl)
        permute = np.random.permutation(len(_label[0]))
        _label = _label[0][permute[:max_nlabel]]

        # _label = _label[0][:max_nlabel]
        # permute = np.random.permutation(len(_label))
        # _label = _label[permute]
        _data = np.concatenate((_data, data[_label]))
        _labels = np.concatenate((_labels, t_labels[_label]))

    # truncate so the number of beats is a multiple of max_time
    data = _data[:(len(_data) // max_time) * max_time, :]
    _labels = _labels[:(len(_data) // max_time) * max_time]

    # data = _data
    # split the beats into sequences of max_time consecutive beats
    data = [data[i:i + max_time] for i in range(0, len(data), max_time)]
    labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)]
    # shuffle the sequences
    permute = np.random.permutation(len(labels))
    data = np.asarray(data)
    labels = np.asarray(labels)
    data = data[permute]
    labels = labels[permute]

    print('Records processed!')

    return data, labels


def evaluate_metrics(confusion_matrix):
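    """Compute per-class metrics from a multi-class confusion matrix.

    Returns (ACC_macro, ACC, TPR, TNR, PPV): macro-averaged accuracy plus
    per-class accuracy, sensitivity (recall), specificity and precision.
    """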
    # per-class TP/FP/FN/TN, following
    # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal
    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)
    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)
    TP = np.diag(confusion_matrix)
    TN = confusion_matrix.sum() - (FP + FN + TP)
    # Sensitivity, hit rate, recall, or true positive rate
    TPR = TP / (TP + FN)
    # Specificity or true negative rate
    TNR = TN / (TN + FP)
    # Precision or positive predictive value
    PPV = TP / (TP + FP)
    # Negative predictive value
    NPV = TN / (TN + FN)
    # Fall-out or false positive rate
    FPR = FP / (FP + TN)
    # False negative rate
    FNR = FN / (TP + FN)
    # False discovery rate
    FDR = FP / (TP + FP)

    # Overall accuracy
    ACC = (TP + TN) / (TP + FP + FN + TN)
    # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN))
    # macro-average over classes, so small classes carry the same weight as large ones
    ACC_macro = np.mean(ACC)

    return ACC_macro, ACC, TPR, TNR, PPV


def batch_data(x, y, batch_size):
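    """Shuffle (x, y) jointly and yield minibatches of batch_size sequences."""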
    shuffle = np.random.permutation(len(x))
    start = 0
    x = x[shuffle]
    y = y[shuffle]
    while start + batch_size <= len(x):
        yield x[start:start + batch_size], y[start:start + batch_size]
        start += batch_size


def build_network(inputs, dec_inputs, char2numY, n_channels=10, input_depth=280,
                  num_units=128, max_time=10, bidirectional=False):
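    """CNN feature extractor followed by an LSTM (or BiLSTM) seq2seq network.

    Each beat is reshaped to (n_channels, input_depth // n_channels) and passed
    through three 1-D conv layers; the flattened per-beat features feed an LSTM
    encoder, and an LSTM decoder over the embedded target tokens produces
    logits of shape [batch, decoder_len, len(char2numY)].

    Note: this uses the TensorFlow 1.x API (tf.layers, tf.contrib.rnn,
    tf.nn.dynamic_rnn); it will not run unchanged on TensorFlow 2.x.
    """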
    _inputs = tf.reshape(inputs, [-1, n_channels, input_depth // n_channels])
    # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels])

    # #(batch*max_time, 280, 1) --> (N, 280, 18)
    conv1 = tf.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)
    max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')

    conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)
    max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')

    conv3 = tf.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)

    shape = conv3.get_shape().as_list()
    data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2]))

    # timesteps = max_time
    #
    # lstm_in = tf.unstack(data_input_embed, timesteps, 1)
    # lstm_size = 128
    # # Get lstm cell output
    # # Add LSTM layers
    # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32)
    # data_input_embed = tf.stack(data_input_embed, 1)

    # shape = data_input_embed.get_shape().as_list()

    embed_size = 10  # 128 lstm_size # shape[1]*shape[2]

    # Embedding layer for the decoder input tokens
    output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')
    data_output_embed = tf.nn.embedding_lookup(output_embedding, dec_inputs)

    with tf.variable_scope("encoding") as encoding_scope:
        if not bidirectional:

            # Regular approach with LSTM units
            lstm_enc = tf.contrib.rnn.LSTMCell(num_units)
            _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32)

        else:

            # Using a bidirectional LSTM architecture instead
            enc_fw_cell = tf.contrib.rnn.LSTMCell(num_units)
            enc_bw_cell = tf.contrib.rnn.LSTMCell(num_units)

            ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=enc_fw_cell,
                cell_bw=enc_bw_cell,
                inputs=data_input_embed,
                dtype=tf.float32)
            enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1)
            enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1)
            last_state = tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h)

    with tf.variable_scope("decoding") as decoding_scope:
        if not bidirectional:
            lstm_dec = tf.contrib.rnn.LSTMCell(num_units)
        else:
            lstm_dec = tf.contrib.rnn.LSTMCell(2 * num_units)

        dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state)

    logits = tf.layers.dense(dec_outputs, units=len(char2numY), use_bias=True)

    return logits


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--max_time', type=int, default=10)
    parser.add_argument('--test_steps', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami')
    parser.add_argument('--bidirectional', type=str2bool, default=False)
    # parser.add_argument('--lstm_layers', type=int, default=2)
    parser.add_argument('--num_units', type=int, default=128)
    parser.add_argument('--n_oversampling', type=int, default=10000)
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq')
    parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih.ckpt')
    parser.add_argument('--classes', nargs='+', type=str,
                        default=['F', 'N', 'S', 'V'])
    args = parser.parse_args()
    run_program(args)


def run_program(args):
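    """Build the seq2seq graph, train it on the SMOTE-oversampled MIT-BIH
    training split (or restore an existing checkpoint) and report per-class
    metrics on the test split.
    """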
221
    print(args)
222
    max_time = args.max_time # 5 3 second best 10# 40 # 100
223
    epochs = args.epochs # 300
224
    batch_size = args.batch_size # 10
225
    num_units = args.num_units
226
    bidirectional = args.bidirectional
227
    # lstm_layers = args.lstm_layers
228
    n_oversampling = args.n_oversampling
229
    checkpoint_dir = args.checkpoint_dir
230
    ckpt_name = args.ckpt_name
231
    test_steps = args.test_steps
232
    classes= args.classes
233
    filename = args.data_dir
234
235
    X, Y = read_mitbih(filename,max_time,classes=classes,max_nlabel=100000) #11000
236
    print ("# of sequences: ", len(X))
237
    input_depth = X.shape[2]
238
    n_channels = 10
239
    classes = np.unique(Y)
240
    char2numY = dict(zip(classes, range(len(classes))))
241
    n_classes = len(classes)
242
    print ('Classes: ', classes)
243
    for cl in classes:
244
        ind = np.where(classes == cl)[0][0]
245
        print (cl, len(np.where(Y.flatten()==cl)[0]))
246
    # char2numX['<PAD>'] = len(char2numX)
247
    # num2charX = dict(zip(char2numX.values(), char2numX.keys()))
248
    # max_len = max([len(date) for date in x])
249
    #
250
    # x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]
251
    # print(''.join([num2charX[x_] for x_ in x[4]]))
252
    # x = np.array(x)
253
254
    char2numY['<GO>'] = len(char2numY)
255
    num2charY = dict(zip(char2numY.values(), char2numY.keys()))
256
257
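    # Prepend the <GO> token to every label sequence: the decoder consumes the
    # sequence shifted right by one (starting from <GO>) and is trained to
    # predict the unshifted labels (teacher forcing).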
    Y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in Y]
    Y = np.array(Y)

    x_seq_length = len(X[0])
    y_seq_length = len(Y[0]) - 1

    # Placeholders
    inputs = tf.placeholder(tf.float32, [None, max_time, input_depth], name='inputs')
    targets = tf.placeholder(tf.int32, (None, None), 'targets')
    dec_inputs = tf.placeholder(tf.int32, (None, None), 'output')

    # logits = build_network(inputs,dec_inputs=dec_inputs)
    logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels,
                           input_depth=input_depth, num_units=num_units,
                           max_time=max_time, bidirectional=bidirectional)
    # decoder_prediction = tf.argmax(logits, 2)
    # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)  # it is wrong
    # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1)

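    # Loss: per-timestep cross entropy over the decoder outputs
    # (tf.contrib.seq2seq.sequence_loss with uniform weights) plus an L2
    # penalty (scaled by beta) on all non-bias trainable variables.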
    with tf.name_scope("optimization"):
        # Loss function
        vars = tf.trainable_variables()
        beta = 0.001
        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars
                           if 'bias' not in v.name]) * beta
        loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))
        # Optimizer
        loss = tf.reduce_mean(loss + lossL2)
        optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)

    # split the dataset into the training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

    # over-sampling: SMOTE
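    # Note: `ratio=` and `fit_sample()` below are the older imblearn API;
    # newer imblearn releases renamed these to `sampling_strategy=` and
    # `fit_resample()`.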
    X_train = np.reshape(X_train, [X_train.shape[0] * X_train.shape[1], -1])
    y_train = y_train[:, 1:].flatten()

    nums = []
    for cl in classes:
        ind = np.where(classes == cl)[0][0]
        nums.append(len(np.where(y_train.flatten() == ind)[0]))
    # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N
    ratio = {0: n_oversampling, 1: nums[1], 2: n_oversampling, 3: n_oversampling}
    sm = SMOTE(random_state=12, ratio=ratio)
    X_train, y_train = sm.fit_sample(X_train, y_train)

    X_train = X_train[:(X_train.shape[0] // max_time) * max_time, :]
    y_train = y_train[:(X_train.shape[0] // max_time) * max_time]

    X_train = np.reshape(X_train, [-1, X_test.shape[1], X_test.shape[2]])
    y_train = np.reshape(y_train, [-1, y_test.shape[1] - 1])
    y_train = [[char2numY['<GO>']] + [y_ for y_ in date] for date in y_train]
    y_train = np.array(y_train)

    print('Classes in the training set: ', classes)
    for cl in classes:
        ind = np.where(classes == cl)[0][0]
        print(cl, len(np.where(y_train.flatten() == ind)[0]))
    print("------------------y_train samples--------------------")
    for ii in range(2):
        print(''.join([num2charY[y_] for y_ in list(y_train[ii + 5])]))
    print("------------------y_test samples--------------------")
    for ii in range(2):
        print(''.join([num2charY[y_] for y_ in list(y_test[ii + 5])]))

    def test_model():
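        """Greedy autoregressive decoding on the test set.

        Starting from the <GO> token, the decoder is run y_seq_length times,
        feeding back its argmax prediction at each step; per-batch confusion
        matrices are collected, averaged, and summarized with evaluate_metrics().
        """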
        # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size))
        acc_track = []
        sum_test_conf = []
        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)):

            dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']
            for i in range(y_seq_length):
                batch_logits = sess.run(logits,
                                        feed_dict={inputs: source_batch, dec_inputs: dec_input})
                prediction = batch_logits[:, -1].argmax(axis=-1)
                dec_input = np.hstack([dec_input, prediction[:, None]])
            # acc_track.append(np.mean(dec_input == target_batch))
            acc_track.append(dec_input[:, 1:] == target_batch[:, 1:])
            y_true = target_batch[:, 1:].flatten()
            y_pred = dec_input[:, 1:].flatten()
            sum_test_conf.append(confusion_matrix(y_true, y_pred, labels=range(len(char2numY) - 1)))

        sum_test_conf = np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0)

        # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track)))

        # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy],
        #                                           feed_dict={inputs: source_batch,
        #                                                      dec_inputs: dec_input[:, :-1],
        #                                                      targets: target_batch[:, 1:]})
        # print (mean_p_class)
        # print (accuracy_classes)
        acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf)
        print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg))
        for index_ in range(n_classes):
            print("\t{} rhythm -> Sensitivity: {:1.4f}, Specificity: {:1.4f}, Precision (PPV): {:1.4f}, Accuracy: {:1.4f}".format(
                classes[index_], sensitivity[index_], specificity[index_], PPV[index_], acc[index_]))
        print("\t Average -> Sensitivity: {:1.4f}, Specificity: {:1.4f}, Precision (PPV): {:1.4f}, Accuracy: {:1.4f}".format(
            np.mean(sensitivity), np.mean(specificity), np.mean(PPV), np.mean(acc)))
        return acc_avg, acc, sensitivity, specificity, PPV
    loss_track = []

    def count_parameters():
        print('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

    count_parameters()

    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    # train the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver = tf.train.Saver()
        print(str(datetime.now()))
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        pre_acc_avg = 0.0
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the saved model and only evaluate it
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name))
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            # or 'load meta graph' and restore weights
            # saver = tf.train.import_meta_graph(ckpt_name+".meta")
            # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir))
            test_model()
        else:

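            # Teacher forcing: the decoder is fed the target sequence shifted
            # right by one token (starting with <GO>) and trained to predict
            # the next label at every step.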
            for epoch_i in range(epochs):
                start_time = time.time()
                train_acc = []
                for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):
                    _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],
                                                           feed_dict={inputs: source_batch,
                                                                      dec_inputs: target_batch[:, :-1],
                                                                      targets: target_batch[:, 1:]})
                    loss_track.append(batch_loss)
                    train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:, 1:])
                # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy],
                #                         feed_dict={inputs: source_batch,
                #                                               dec_inputs: target_batch[:, :-1],
                #                                               targets: target_batch[:, 1:]})

                # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])
                accuracy = np.mean(train_acc)
                print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(
                    epoch_i, batch_loss, accuracy, time.time() - start_time))

                if epoch_i % test_steps == 0:
                    acc_avg, acc, sensitivity, specificity, PPV = test_model()

                    print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))
                    save_path = os.path.join(checkpoint_dir, ckpt_name)
                    saver.save(sess, save_path)
                    print("Model saved in path: %s" % save_path)

                    # if np.nan_to_num(acc_avg) > pre_acc_avg:  # save the better model based on the f1 score
                    #     print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))
                    #     pre_acc_avg = acc_avg
                    #     save_path = os.path.join(checkpoint_dir, ckpt_name)
                    #     saver.save(sess, save_path)
                    #     print("The best model (till now) saved in path: %s" % save_path)

            plt.plot(loss_track)
            plt.show()
        print(str(datetime.now()))
        # test_model()


if __name__ == '__main__':
    main()
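# Example invocation (assuming the prepared dataset is at data/s2s_mitbih_aami.mat):
#   python seq_seq_annot_aami.py --data_dir data/s2s_mitbih_aami --epochs 500 \
#       --max_time 10 --batch_size 20 --bidirectional False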