tf2/seq_seq_annot_aami.py
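# Sequence-to-sequence ECG heartbeat classification on the MIT-BIH arrhythmia
# database, grouped into AAMI classes (N, S, V, F, Q). A 1-D CNN extracts
# per-beat features, an LSTM encoder summarizes each sequence of beats, and an
# LSTM decoder emits one label per beat (TF1-style graph run via tf.compat.v1).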
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as spio
from sklearn.preprocessing import MinMaxScaler
import random
import time
import os
from datetime import datetime
from sklearn.metrics import confusion_matrix
import tensorflow as tf
import tensorflow_addons as tfa
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
import argparse

random.seed(654)
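# Load the pre-segmented MIT-BIH beats from <filename>.mat (expects a struct
# 's2s_mitbih' with per-record 'seg_values' / 'seg_labels' fields), select at
# most max_nlabel beats per class, and return them grouped into length-max_time
# sequences together with the matching label sequences.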
def read_mitbih(filename, max_time=100, classes=['F', 'N', 'S', 'V', 'Q'], max_nlabel=100):
    def normalize(data):
        data = np.nan_to_num(data)  # removing NaNs and Infs
        data = data - np.mean(data)
        data = data / np.std(data)
        return data

    # read data
    data = []
    samples = spio.loadmat(filename + ".mat")
    samples = samples['s2s_mitbih']
    values = samples[0]['seg_values']
    labels = samples[0]['seg_labels']
    num_annots = sum([item.shape[0] for item in values])

    n_seqs = num_annots // max_time  # integer division: number of complete max_time-beat sequences
    # add all segments (beats) together
    l_data = 0
    for i, item in enumerate(values):
        l = item.shape[0]
        for itm in item:
            if l_data == n_seqs * max_time:
                break
            data.append(itm[0])
            l_data = l_data + 1

    # add all labels together
    l_labels = 0
    t_labels = []
    for i, item in enumerate(labels):
        if len(t_labels) == n_seqs * max_time:
            break
        item = item[0]
        for label in item:
            if l_labels == n_seqs * max_time:
                break
            t_labels.append(str(label))
            l_labels = l_labels + 1

    del values
    data = np.asarray(data)
    shape_v = data.shape
    data = np.reshape(data, [shape_v[0], -1])
    t_labels = np.array(t_labels)
    _data = np.asarray([], dtype=np.float64).reshape(0, shape_v[1])
    _labels = np.asarray([], dtype=t_labels.dtype).reshape(0,)  # match the label dtype for concatenation
    for cl in classes:
        _label = np.where(t_labels == cl)
        permute = np.random.permutation(len(_label[0]))
        _label = _label[0][permute[:max_nlabel]]

        # _label = _label[0][:max_nlabel]
        # permute = np.random.permutation(len(_label))
        # _label = _label[permute]
        _data = np.concatenate((_data, data[_label]))
        _labels = np.concatenate((_labels, t_labels[_label]))

    data = _data[:int(len(_data) / max_time) * max_time, :]
    _labels = _labels[:int(len(_data) / max_time) * max_time]

    # data = _data
    # split data into sublists of max_time beats each
    data = [data[i:i + max_time] for i in range(0, len(data), max_time)]
    labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)]
    # shuffle
    permute = np.random.permutation(len(labels))
    data = np.asarray(data)
    labels = np.asarray(labels)
    data = data[permute]
    labels = labels[permute]

    print('Records processed!')

    return data, labels
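# Per-class metrics derived from a multi-class confusion matrix (rows = true
# labels, columns = predictions, the sklearn convention). For example, for the
# 2x2 matrix [[8, 2], [1, 9]], class 0 has TP=8, FN=2, FP=1, TN=9, giving
# TPR=0.8, PPV=8/9 and ACC=(8+9)/20=0.85.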
def evaluate_metrics(cm):
    # cm: confusion matrix (renamed to avoid shadowing sklearn's confusion_matrix)
    # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal
    FP = cm.sum(axis=0) - np.diag(cm)
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    # Sensitivity, hit rate, recall, or true positive rate
    TPR = TP / (TP + FN)
    # Specificity or true negative rate
    TNR = TN / (TN + FP)
    # Precision or positive predictive value
    PPV = TP / (TP + FP)
    # Negative predictive value
    NPV = TN / (TN + FN)
    # Fall out or false positive rate
    FPR = FP / (FP + TN)
    # False negative rate
    FNR = FN / (TP + FN)
    # False discovery rate
    FDR = FP / (TP + FP)

    # Overall accuracy
    ACC = (TP + TN) / (TP + FP + FN + TN)
    # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN))
    # macro-average accuracy gives equal weight to the small classes
    ACC_macro = np.mean(ACC)

    return ACC_macro, ACC, TPR, TNR, PPV
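# Shuffle the sequences and yield (x, y) mini-batches; any trailing examples
# that do not fill a complete batch are dropped.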
def batch_data(x, y, batch_size):
    shuffle = np.random.permutation(len(x))
    start = 0
    #     from IPython.core.debugger import Tracer; Tracer()()
    x = x[shuffle]
    y = y[shuffle]
    while start + batch_size <= len(x):
        yield x[start:start + batch_size], y[start:start + batch_size]
        start += batch_size
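# Build the seq2seq graph. Each beat (input_depth samples) is reshaped into
# n_channels "timesteps" and passed through three conv1d/max-pool blocks; the
# flattened CNN features of the max_time beats in a sequence feed an LSTM
# encoder, whose final state initializes an LSTM decoder that consumes the
# embedded previous label and predicts the next one. With the defaults used
# here (input_depth=280, n_channels=10, max_time=10) the shapes are roughly
# (batch*10, 10, 28) -> conv/pool -> (batch*10, 3, 128) -> (batch, 10, 384).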
def build_network(inputs, dec_inputs, char2numY, n_channels=10, input_depth=280, num_units=128, max_time=10, bidirectional=False):
    _inputs = tf.reshape(inputs, [-1, n_channels, int(input_depth / n_channels)])
    # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels])

    # (batch*max_time, n_channels, input_depth/n_channels) --> CNN features per beat
    conv1 = tf.compat.v1.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)
    max_pool_1 = tf.compat.v1.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')

    conv2 = tf.compat.v1.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)
    max_pool_2 = tf.compat.v1.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')

    conv3 = tf.compat.v1.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1,
                             padding='same', activation=tf.nn.relu)

    shape = conv3.get_shape().as_list()
    data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2]))

    # timesteps = max_time
    #
    # lstm_in = tf.unstack(data_input_embed, timesteps, 1)
    # lstm_size = 128
    # # Get lstm cell output
    # # Add LSTM layers
    # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32)
    # data_input_embed = tf.stack(data_input_embed, 1)

    # shape = data_input_embed.get_shape().as_list()

    embed_size = 10  # 128 lstm_size # shape[1]*shape[2]

    # Embedding layers
    output_embedding = tf.Variable(tf.random.uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')
    data_output_embed = tf.nn.embedding_lookup(params=output_embedding, ids=dec_inputs)

    with tf.compat.v1.variable_scope("encoding") as encoding_scope:
        if not bidirectional:

            # Regular approach with LSTM units
            lstm_enc = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
            _, last_state = tf.compat.v1.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32)

        else:

            # Using a bidirectional LSTM architecture instead
            enc_fw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
            enc_bw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)

            ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.compat.v1.nn.bidirectional_dynamic_rnn(
                cell_fw=enc_fw_cell,
                cell_bw=enc_bw_cell,
                inputs=data_input_embed,
                dtype=tf.float32)
            enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1)
            enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1)
            last_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h)

    with tf.compat.v1.variable_scope("decoding") as decoding_scope:
        if not bidirectional:
            lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
        else:
            lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(2 * num_units)

        dec_outputs, _ = tf.compat.v1.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state)

    logits = tf.compat.v1.layers.dense(dec_outputs, units=len(char2numY), use_bias=True)

    return logits
def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
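# Example invocation (the values shown are the argparse defaults below):
#   python tf2/seq_seq_annot_aami.py --data_dir data/s2s_mitbih_aami --epochs 500 \
#       --max_time 10 --batch_size 20 --num_units 128 --bidirectional False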
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--max_time', type=int, default=10)
    parser.add_argument('--test_steps', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami')
    parser.add_argument('--bidirectional', type=str2bool, default=False)
    # parser.add_argument('--lstm_layers', type=int, default=2)
    parser.add_argument('--num_units', type=int, default=128)
    parser.add_argument('--n_oversampling', type=int, default=10000)
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq')
    parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih.ckpt')
    parser.add_argument('--classes', nargs='+', type=str,
                        default=['F', 'N', 'S', 'V'])
    args = parser.parse_args()
    run_program(args)
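# Training/evaluation driver: loads the sequences, builds the graph, balances
# the training beats with SMOTE, then either restores an existing checkpoint
# and evaluates, or trains from scratch and checkpoints every test_steps epochs.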
def run_program(args):
    print(args)
    max_time = args.max_time # 5 3 second best 10# 40 # 100
    epochs = args.epochs # 300
    batch_size = args.batch_size # 10
    num_units = args.num_units
    bidirectional = args.bidirectional
    # lstm_layers = args.lstm_layers
    n_oversampling = args.n_oversampling
    checkpoint_dir = args.checkpoint_dir
    ckpt_name = args.ckpt_name
    test_steps = args.test_steps
    classes = args.classes
    filename = args.data_dir

    X, Y = read_mitbih(filename, max_time, classes=classes, max_nlabel=100000)  # 11000
    print(("# of sequences: ", len(X)))
    input_depth = X.shape[2]
    n_channels = 10
    classes = np.unique(Y)
    char2numY = dict(list(zip(classes, list(range(len(classes))))))
    n_classes = len(classes)
    print(('Classes: ', classes))
    for cl in classes:
        ind = np.where(classes == cl)[0][0]
        print((cl, len(np.where(Y.flatten() == cl)[0])))
    # char2numX['<PAD>'] = len(char2numX)
    # num2charX = dict(zip(char2numX.values(), char2numX.keys()))
    # max_len = max([len(date) for date in x])
    #
    # x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]
    # print(''.join([num2charX[x_] for x_ in x[4]]))
    # x = np.array(x)

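    # Prepend a '<GO>' token to every label sequence: during training the
    # decoder is fed target_batch[:, :-1] (teacher forcing) and trained to
    # predict target_batch[:, 1:]; at test time decoding starts from '<GO>'.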
    char2numY['<GO>'] = len(char2numY)
    num2charY = dict(list(zip(list(char2numY.values()), list(char2numY.keys()))))

    Y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in Y]
    Y = np.array(Y)

    x_seq_length = len(X[0])
    y_seq_length = len(Y[0]) - 1

    # Placeholders
    tf.compat.v1.disable_eager_execution()
    inputs = tf.compat.v1.placeholder(tf.float32, [None, max_time, input_depth], name='inputs')
    targets = tf.compat.v1.placeholder(tf.int32, (None, None), 'targets')
    dec_inputs = tf.compat.v1.placeholder(tf.int32, (None, None), 'output')

    # logits = build_network(inputs,dec_inputs=dec_inputs)
    logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth,
                           num_units=num_units, max_time=max_time, bidirectional=bidirectional)
    # decoder_prediction = tf.argmax(logits, 2)
    # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong
    # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1)

    with tf.compat.v1.name_scope("optimization"):
        # Loss function
        vars = tf.compat.v1.trainable_variables()
        beta = 0.001
        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars
                            if 'bias' not in v.name]) * beta
        loss = tfa.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))
        # Optimizer
        loss = tf.reduce_mean(input_tensor=loss + lossL2)
        optimizer = tf.compat.v1.train.RMSPropOptimizer(1e-3).minimize(loss)

    # split the dataset into the training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

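    # Class balancing works on individual beats: the training sequences are
    # flattened, each minority class is oversampled with SMOTE to
    # n_oversampling beats while 'N' keeps its original count, and the result
    # is regrouped into max_time-long sequences.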
    # over-sampling: SMOTE
    X_train = np.reshape(X_train, [X_train.shape[0]*X_train.shape[1], -1])
    y_train = y_train[:, 1:].flatten()

    nums = []
    for cl in classes:
        ind = np.where(classes == cl)[0][0]
        nums.append(len(np.where(y_train.flatten() == ind)[0]))
    # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N
    ratio = {0: n_oversampling, 1: nums[1], 2: n_oversampling, 3: n_oversampling}
    # imblearn >= 0.6 renamed 'ratio' to 'sampling_strategy' and 'fit_sample' to 'fit_resample'
    sm = SMOTE(random_state=12, sampling_strategy=ratio)
    X_train, y_train = sm.fit_resample(X_train, y_train)

    X_train = X_train[:int(X_train.shape[0]/max_time)*max_time, :]
    y_train = y_train[:int(X_train.shape[0]/max_time)*max_time]

    X_train = np.reshape(X_train, [-1, X_test.shape[1], X_test.shape[2]])
    y_train = np.reshape(y_train, [-1, y_test.shape[1] - 1])
    y_train = [[char2numY['<GO>']] + [y_ for y_ in date] for date in y_train]
    y_train = np.array(y_train)

    print(('Classes in the training set: ', classes))
    for cl in classes:
        ind = np.where(classes == cl)[0][0]
        print((cl, len(np.where(y_train.flatten() == ind)[0])))
    print("------------------y_train samples--------------------")
    for ii in range(2):
        print((''.join([num2charY[y_] for y_ in list(y_train[ii+5])])))
    print("------------------y_test samples--------------------")
    for ii in range(2):
        print((''.join([num2charY[y_] for y_ in list(y_test[ii+5])])))

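    # Evaluate on the test split with greedy decoding: starting from '<GO>',
    # the decoder's own prediction at each step is fed back as the next input,
    # and per-class metrics are computed from the batch-averaged confusion matrix.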
    def test_model():
        # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size))
        acc_track = []
        sum_test_conf = []
        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)):

            dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']
            for i in range(y_seq_length):
                batch_logits = sess.run(logits,
                                        feed_dict={inputs: source_batch, dec_inputs: dec_input})
                prediction = batch_logits[:, -1].argmax(axis=-1)
                dec_input = np.hstack([dec_input, prediction[:, None]])
            # acc_track.append(np.mean(dec_input == target_batch))
            acc_track.append(dec_input[:, 1:] == target_batch[:, 1:])
            y_true = target_batch[:, 1:].flatten()
            y_pred = dec_input[:, 1:].flatten()
            sum_test_conf.append(confusion_matrix(y_true, y_pred, labels=list(range(len(char2numY)-1))))

        sum_test_conf = np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0)

        # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track)))

        # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy],
        #                                           feed_dict={inputs: source_batch,
        #                                                      dec_inputs: dec_input[:, :-1],
        #                                                      targets: target_batch[:, 1:]})
        # print (mean_p_class)
        # print (accuracy_classes)
        acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf)
        print(('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg)))
        for index_ in range(n_classes):
            print(("\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(
                classes[index_], sensitivity[index_], specificity[index_], PPV[index_], acc[index_])))
        print(("\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(
            np.mean(sensitivity), np.mean(specificity), np.mean(PPV), np.mean(acc))))
        return acc_avg, acc, sensitivity, specificity, PPV

    loss_track = []

    def count_parameters():
        print(('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])))

    count_parameters()

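    # Either restore an existing checkpoint from checkpoint_dir and run the
    # test set once, or train for `epochs` epochs, evaluating and saving a
    # checkpoint every `test_steps` epochs.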
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    # train the graph
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(tf.compat.v1.local_variables_initializer())
        saver = tf.compat.v1.train.Saver()
        print((str(datetime.now())))
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        pre_acc_avg = 0.0
        if ckpt and ckpt.model_checkpoint_path:
            # # Restore
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name))
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            # or 'load meta graph' and restore weights
            # saver = tf.train.import_meta_graph(ckpt_name+".meta")
            # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir))
            test_model()
        else:

            for epoch_i in range(epochs):
                start_time = time.time()
                train_acc = []
                for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):
                    _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],
                        feed_dict={inputs: source_batch,
                                   dec_inputs: target_batch[:, :-1],
                                   targets: target_batch[:, 1:]})
                    loss_track.append(batch_loss)
                    train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:, 1:])
                # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy],
                #                         feed_dict={inputs: source_batch,
                #                                               dec_inputs: target_batch[:, :-1],
                #                                               targets: target_batch[:, 1:]})

                # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])
                accuracy = np.mean(train_acc)
                print(('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(
                    epoch_i, batch_loss, accuracy, time.time() - start_time)))

                if epoch_i % test_steps == 0:
                    acc_avg, acc, sensitivity, specificity, PPV = test_model()

                    print(('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)))
                    save_path = os.path.join(checkpoint_dir, ckpt_name)
                    saver.save(sess, save_path)
                    print(("Model saved in path: %s" % save_path))

                    # if np.nan_to_num(acc_avg) > pre_acc_avg:  # save the better model based on the f1 score
                    #     print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))
                    #     pre_acc_avg = acc_avg
                    #     save_path =os.path.join(checkpoint_dir, ckpt_name)
                    #     saver.save(sess, save_path)
                    #     print("The best model (till now) saved in path: %s" % save_path)

            plt.plot(loss_track)
            plt.show()
        print((str(datetime.now())))
        # test_model()

if __name__ == '__main__':
    main()