# EEGLearn/train.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yang Wang
## School of Automation, Huazhong University of Science & Technology (HUST)
## wangyang_sky@hust.edu.cn
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10
11
#coding:utf-8
12
13
import os
14
import tensorflow as tf
15
import numpy as np
16
import scipy.io
17
import time
18
import datetime
19
20
from utils import reformatInput, load_or_generate_images, iterate_minibatches
21
22
from model import build_cnn, build_convpool_conv1d, build_convpool_lstm, build_convpool_mix
23
24
25
timestamp = datetime.datetime.now().strftime('%Y-%m-%d.%H.%M')
log_path = os.path.join("runs", timestamp)

# Model to train; must be one of the keys of ``lrs`` below:
# '1dconv', 'lstm', 'mix' or 'cnn'.
model_type = '1dconv'
log_path = log_path + '_' + model_type

batch_size = 32
dropout_rate = 0.5

input_shape = [32, 32, 3]   # 1024
nb_class = 4
n_colors = 3

# whether to train cnn first, and load its weight for multi-frame model
reuse_cnn_flag = False

# Base learning rate for each model, expressed for a reference batch size of 32
# (see the linear scaling applied to ``learning_rate`` below).
lrs = {
    'cnn': 1e-3,
    '1dconv': 1e-4,
    'lstm': 1e-4,
    'mix': 1e-4,
}
# Fail fast with a readable message instead of a bare KeyError from lrs[model_type].
if model_type not in lrs:
    raise ValueError("Unsupported model_type {!r}; expected one of {}".format(
        model_type, sorted(lrs)))

# Coefficient of the L2 penalty added to the cross-entropy loss in train().
weight_decay = 1e-4
# Scale the base learning rate linearly with the batch size (reference size 32).
learning_rate = lrs[model_type] / 32 * batch_size
# Optimizer class (not instance); train() instantiates it with the decayed learning rate.
optimizer = tf.train.AdamOptimizer

num_epochs = 60
56
def train(images, labels, fold, model_type, batch_size, num_epochs, subj_id=0, reuse_cnn=False,
          dropout_rate=dropout_rate, learning_rate_default=1e-3, Optimizer=tf.train.AdamOptimizer,
          log_path=log_path):
    """
    Train one model on one cross-validation fold with early stopping.

    Each epoch loops over the training set in mini-batches, then evaluates the
    network on the whole validation and test sets. A checkpoint is saved whenever
    validation accuracy improves; training stops once validation accuracy has not
    improved for 10 consecutive epochs. Finally the accuracy of the last (not the
    best-checkpoint) weights on the training set is evaluated batch-by-batch with
    dropout disabled.

    :param images: input images
    :param labels: target labels
    :param fold: tuple of (train, test) index numbers
    :param model_type: model type ('cnn', '1dconv', 'lstm', 'mix')
    :param batch_size: batch size for training
    :param num_epochs: maximum number of epochs of dataset to go over for training
    :param subj_id: the id of fold for storing log and the best model
    :param reuse_cnn: whether to train cnn first, and load its weight for multi-frame model
    :param dropout_rate: dropout rate passed to the model's dropout layers
    :param learning_rate_default: base learning rate before exponential decay
    :param Optimizer: optimizer class, instantiated with the decayed learning rate
    :param log_path: root directory for summaries and checkpoints
    :return: [last_train_acc, best_validation_accu, test_acc_val, last_val_acc, last_test_acc]
    :raises ValueError: for an unknown model_type, when no training batch is produced,
        or when reuse_cnn is set but no CNN checkpoint exists
    """
    patience = 10  # epochs without validation improvement before stopping

    with tf.name_scope('Inputs'):
        input_var = tf.placeholder(tf.float32, [None, None, 32, 32, n_colors], name='X_inputs')
        target_var = tf.placeholder(tf.int64, [None], name='y_inputs')
        tf_is_training = tf.placeholder(tf.bool, None, name='is_training')

    num_classes = len(np.unique(labels))
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = reformatInput(images, labels, fold)

    print('Train set label and proportion:\t', np.unique(y_train, return_counts=True))
    print('Val   set label and proportion:\t', np.unique(y_val, return_counts=True))
    print('Test  set label and proportion:\t', np.unique(y_test, return_counts=True))

    print('The shape of X_trian:\t', X_train.shape)
    print('The shape of X_val:\t', X_val.shape)
    print('The shape of X_test:\t', X_test.shape)

    print("Building model and compiling functions...")
    if model_type == '1dconv':
        network = build_convpool_conv1d(input_var, num_classes, train=tf_is_training,
                                        dropout_rate=dropout_rate, name='CNN_Conv1d'+'_sbj'+str(subj_id))
    elif model_type == 'lstm':
        network = build_convpool_lstm(input_var, num_classes, 100, train=tf_is_training,
                                      dropout_rate=dropout_rate, name='CNN_LSTM'+'_sbj'+str(subj_id))
    elif model_type == 'mix':
        network = build_convpool_mix(input_var, num_classes, 100, train=tf_is_training,
                                     dropout_rate=dropout_rate, name='CNN_Mix'+'_sbj'+str(subj_id))
    elif model_type == 'cnn':
        with tf.name_scope(name='CNN_layer'+'_fold'+str(subj_id)):
            network = build_cnn(input_var)  # output shape [None, 4, 4, 128]
            convpool_flat = tf.reshape(network, [-1, 4*4*128])
            h_fc1_drop1 = tf.layers.dropout(convpool_flat, rate=dropout_rate, training=tf_is_training, name='dropout_1')
            h_fc1 = tf.layers.dense(h_fc1_drop1, 256, activation=tf.nn.relu, name='fc_relu_256')
            h_fc1_drop2 = tf.layers.dropout(h_fc1, rate=dropout_rate, training=tf_is_training, name='dropout_2')
            # Raw logits: the softmax is applied inside the cross-entropy loss.
            network = tf.layers.dense(h_fc1_drop2, num_classes, name='fc_softmax')
    else:
        raise ValueError("Model not supported ['1dconv', 'lstm', 'mix', 'cnn']")

    Train_vars = tf.trainable_variables()
    logits = network

    with tf.name_scope('Loss'):
        # L2 penalty over every trainable variable whose name contains 'kernel'.
        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in Train_vars if 'kernel' in v.name])
        ce_loss = tf.losses.sparse_softmax_cross_entropy(labels=target_var, logits=logits)
        _loss = ce_loss + weight_decay*l2_loss

    # Decay the learning rate every 3 "epochs" of len(y_train)//batch_size steps;
    # clamp to at least one step so a tiny training set cannot give decay_steps == 0.
    decay_steps = 3 * max(len(y_train) // batch_size, 1)
    with tf.name_scope('Optimizer'):
        # learning_rate = learning_rate_default * Decay_rate^(global_steps/decay_steps)
        global_steps = tf.Variable(0, name="global_step", trainable=False)
        learning_rate = tf.train.exponential_decay(     # learning rate decay
            learning_rate_default,  # Base learning rate.
            global_steps,
            decay_steps,
            0.95,  # Decay rate.
            staircase=True)
        optimizer = Optimizer(learning_rate)    # GradientDescentOptimizer  AdamOptimizer
        train_op = optimizer.minimize(_loss, global_step=global_steps, var_list=Train_vars)

    with tf.name_scope('Accuracy'):
        prediction = tf.argmax(logits, axis=1)
        correct_prediction = tf.equal(prediction, target_var)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Output directory for models and summaries
    # choose different path for different model and subject
    out_dir = os.path.abspath(os.path.join(os.path.curdir, log_path, (model_type+'_'+str(subj_id))))
    print("Writing to {}\n".format(out_dir))

    # Summaries for loss, accuracy and learning_rate
    loss_summary = tf.summary.scalar('loss', _loss)
    acc_summary = tf.summary.scalar('train_acc', accuracy)
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)

    # Train Summaries
    train_summary_op = tf.summary.merge([loss_summary, acc_summary, lr_summary])
    train_summary_dir = os.path.join(out_dir, "summaries", "train")
    train_summary_writer = tf.summary.FileWriter(train_summary_dir, tf.get_default_graph())

    # Dev summaries
    dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
    dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
    dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, tf.get_default_graph())

    # Test summaries
    test_summary_op = tf.summary.merge([loss_summary, acc_summary])
    test_summary_dir = os.path.join(out_dir, "summaries", "test")
    test_summary_writer = tf.summary.FileWriter(test_summary_dir, tf.get_default_graph())

    # Writers are closed in `finally` so event files are flushed even if training fails.
    try:
        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, model_type)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        if model_type != 'cnn' and reuse_cnn:
            # saver for reuse the CNN weight
            reuse_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='VGG_NET_CNN')
            original_saver = tf.train.Saver(reuse_vars)         # Pass the variables as a list

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

        print("Starting training...")
        total_start_time = time.time()
        best_validation_accu = 0
        # Defaults so the summary below is well-defined even if validation accuracy
        # never improves on 0 or num_epochs == 0.
        early_stopping_epoch = 0
        test_acc_val = 0.0
        av_val_acc = av_test_acc = 0.0

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        with tf.Session() as sess:
            sess.run(init_op)
            if model_type != 'cnn' and reuse_cnn:
                cnn_ckpt_dir = os.path.abspath(
                    os.path.join(os.path.curdir, log_path, ('cnn_'+str(subj_id)), 'checkpoints'))
                cnn_model_path = tf.train.latest_checkpoint(cnn_ckpt_dir)
                print('-'*20)
                print('Load cnn model weight for multi-frame model from {}'.format(cnn_model_path))
                if cnn_model_path is None:
                    raise ValueError('No CNN checkpoint found in {}; train the cnn model first'.format(cnn_ckpt_dir))
                original_saver.restore(sess, cnn_model_path)

            stop_count = 0  # epochs since the last validation improvement
            for epoch in range(num_epochs):
                print('-'*50)
                # Train set
                train_err = train_acc = train_batches = 0
                start_time = time.time()
                for inputs, targets in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
                    summary, _, loss, acc = sess.run(
                        [train_summary_op, train_op, _loss, accuracy],
                        {input_var: inputs, target_var: targets, tf_is_training: True})
                    train_acc += acc
                    train_err += loss
                    train_batches += 1
                    train_summary_writer.add_summary(summary, sess.run(global_steps))

                if train_batches == 0:
                    raise ValueError('iterate_minibatches produced no training batches '
                                     '(len(X_train)={}, batch_size={})'.format(len(X_train), batch_size))
                av_train_err = train_err / train_batches
                av_train_acc = train_acc / train_batches

                # The global step does not change during evaluation; fetch it once.
                step = sess.run(global_steps)

                # Val set
                summary, av_val_err, av_val_acc = sess.run(
                    [dev_summary_op, _loss, accuracy],
                    {input_var: X_val, target_var: y_val, tf_is_training: False})
                dev_summary_writer.add_summary(summary, step)

                print("Epoch {} of {} took {:.3f}s".format(
                    epoch + 1, num_epochs, time.time() - start_time))

                fmt_str = "Train \tEpoch [{:d}/{:d}]  train_Loss: {:.4f}\ttrain_Acc: {:.2f}"
                print_str = fmt_str.format(epoch + 1, num_epochs, av_train_err, av_train_acc*100)
                print(print_str)

                fmt_str = "Val \tEpoch [{:d}/{:d}]  val_Loss: {:.4f}\tval_Acc: {:.2f}"
                print_str = fmt_str.format(epoch + 1, num_epochs, av_val_err, av_val_acc*100)
                print(print_str)

                # Test set
                summary, av_test_err, av_test_acc = sess.run(
                    [test_summary_op, _loss, accuracy],
                    {input_var: X_test, target_var: y_test, tf_is_training: False})
                test_summary_writer.add_summary(summary, step)

                fmt_str = "Test \tEpoch [{:d}/{:d}]  test_Loss: {:.4f}\ttest_Acc: {:.2f}"
                print_str = fmt_str.format(epoch + 1, num_epochs, av_test_err, av_test_acc*100)
                print(print_str)

                if av_val_acc > best_validation_accu:   # early stopping
                    stop_count = 0
                    early_stopping_epoch = epoch
                    best_validation_accu = av_val_acc
                    test_acc_val = av_test_acc
                    saver.save(sess, checkpoint_prefix, global_step=step)
                else:
                    stop_count += 1
                    if stop_count >= patience:  # val_acc has not improved for `patience` epochs
                        break

            # Training accuracy of the final weights, evaluated on each mini-batch
            # (previously the whole X_train was fed on every iteration).
            train_batches = train_acc = 0
            for inputs, targets in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
                train_acc += sess.run(accuracy, {input_var: inputs, target_var: targets, tf_is_training: False})
                train_batches += 1

            last_train_acc = train_acc / max(train_batches, 1)
            last_val_acc = av_val_acc
            last_test_acc = av_test_acc
            print('-'*50)
            print('Time in total:', time.time()-total_start_time)
            print("Best validation accuracy:\t\t{:.2f} %".format(best_validation_accu * 100))
            print("Test accuracy when got the best validation accuracy:\t\t{:.2f} %".format(test_acc_val * 100))
            print('-'*50)
            print("Last train accuracy:\t\t{:.2f} %".format(last_train_acc * 100))
            print("Last validation accuracy:\t\t{:.2f} %".format(last_val_acc * 100))
            print("Last test accuracy:\t\t\t\t{:.2f} %".format(last_test_acc * 100))
            print('Early Stopping at epoch: {}'.format(early_stopping_epoch+1))
    finally:
        train_summary_writer.close()
        dev_summary_writer.close()
        test_summary_writer.close()
    return [last_train_acc, best_validation_accu, test_acc_val, last_val_acc, last_test_acc]
278
279
280
281
def train_all_model(num_epochs=3000):
    """
    Run leave-one-subject-out cross validation for the module-level ``model_type``.

    One fold is built per unique subject id in ``../SampleData/trials_subNums.mat``:
    that subject's trials are the test indices, all other trials the training
    indices. Each fold is trained with ``train`` in a fresh TensorFlow graph. When
    ``reuse_cnn_flag`` is set, a single-frame 'cnn' model is trained first for the
    fold so the multi-frame model can load its weights.

    :param num_epochs: maximum number of epochs per fold (early stopping may end a fold sooner)
    :return: none; per-fold accuracies are printed and saved to ./Accuracy_<model_type>.csv
    :raises ValueError: if no subject ids are found
    """
    # tf.reset_default_graph() discards the graph-level seed set by tf.set_random_seed(),
    # which would leave every fold after the first unseeded. Remember it and re-apply it.
    graph_seed = tf.get_default_graph().seed

    def _reset_graph():
        """Start a fresh default graph that keeps the original graph-level seed."""
        tf.reset_default_graph()
        tf.set_random_seed(graph_seed)

    # Leave-Subject-Out cross validation
    subj_nums = np.squeeze(scipy.io.loadmat('../SampleData/trials_subNums.mat')['subjectNum'])
    fold_pairs = []
    for i in np.unique(subj_nums):
        is_test = subj_nums == i
        # flatnonzero always returns 1-D indices; np.squeeze(np.nonzero(...)) collapses a
        # single-trial subject to a 0-d array, which np.random.shuffle rejects.
        tr = np.flatnonzero(~is_test)
        ts = np.flatnonzero(is_test)
        np.random.shuffle(tr)
        np.random.shuffle(ts)
        fold_pairs.append((tr, ts))
    if not fold_pairs:
        raise ValueError('No subjects found in ../SampleData/trials_subNums.mat')
    # One fold per subject present in the data instead of a hard-coded count.
    nums_subject = len(fold_pairs)

    images_average, images_timewin, labels = load_or_generate_images(
                                                file_path='../SampleData/', average_image=3)

    print('*'*200)
    acc_buf = []
    for subj_id in range(nums_subject):
        print('-'*100)

        if model_type == 'cnn':
            print('The subjects', subj_id, '\t\t Training the ' + 'cnn' + ' Model...')
            acc_temp = train(images_average, labels, fold_pairs[subj_id], 'cnn',
                             batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
                             learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
            acc_buf.append(acc_temp)
            _reset_graph()
            print('Done!')

        else:
            # whether to train cnn first, and load its weight for multi-frame model
            if reuse_cnn_flag:
                print('The subjects', subj_id, '\t\t Training the ' + 'cnn' + ' Model...')
                # Only the multi-frame model's accuracies are recorded below.
                train(images_average, labels, fold_pairs[subj_id], 'cnn',
                      batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
                      learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
                _reset_graph()
                print('Done!')

            print('The subjects', subj_id, '\t\t Training the ' + model_type + ' Model...')
            if reuse_cnn_flag:
                print('Load the CNN model weight for backbone...')
            acc_temp = train(images_timewin, labels, fold_pairs[subj_id], model_type,
                             batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
                             reuse_cnn=reuse_cnn_flag, learning_rate_default=learning_rate,
                             Optimizer=optimizer, log_path=log_path)
            acc_buf.append(acc_temp)
            _reset_graph()
            print('Done!')

    print('All folds for {} are done!'.format(model_type))
    # Rows: metric (as returned by train); columns: folds, plus a final mean column.
    acc_buf = np.array(acc_buf).T
    acc_mean = np.mean(acc_buf, axis=1).reshape(-1, 1)
    acc_buf = np.concatenate([acc_buf, acc_mean], axis=1)
    row_names = ['Last_train_acc:\t', 'Best_val_acc:\t', 'Earlystopping_test_acc:\t',
                 'Last_val_acc:\t', 'Last_test_acc:\t']
    for row_name, row in zip(row_names, acc_buf):
        # the last column is the mean of current row
        print(row_name, row, '\tmean :', row[-1])
    np.savetxt('./Accuracy_{}.csv'.format(model_type), acc_buf, fmt='%.4f', delimiter=',')
347
348
349
if __name__ == '__main__':
    # Expose only GPU 0 and filter TensorFlow's C++ INFO/WARNING log output.
    os.environ.update([("CUDA_VISIBLE_DEVICES", "0"), ('TF_CPP_MIN_LOG_LEVEL', '2')])

    # Fix NumPy's global RNG (used to shuffle the folds) and TensorFlow's graph-level seed.
    np.random.seed(2018)
    tf.set_random_seed(2018)

    train_all_model(num_epochs=num_epochs)