Diff of /EEGLearn/train.py [000000] .. [117083]


--- a
+++ b/EEGLearn/train.py
@@ -0,0 +1,355 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Yang Wang
+## School of Automation, Huazhong University of Science & Technology (HUST)
+## wangyang_sky@hust.edu.cn
+## Copyright (c) 2018
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+#coding:utf-8
+
+import os
+import tensorflow as tf
+import numpy as np
+import scipy.io
+import time
+import datetime
+
+from utils import reformatInput, load_or_generate_images, iterate_minibatches
+
+from model import build_cnn, build_convpool_conv1d, build_convpool_lstm, build_convpool_mix
+
+
+timestamp = datetime.datetime.now().strftime('%Y-%m-%d.%H.%M')
+log_path = os.path.join("runs", timestamp)
+
+
+model_type = '1dconv'      # one of ['cnn', '1dconv', 'lstm', 'mix']
+log_path = log_path + '_' + model_type
+
+batch_size = 32
+dropout_rate = 0.5
+
+input_shape = [32, 32, 3]   # 32x32 images, 3 color channels (32*32 = 1024 pixels per channel)
+nb_class = 4
+n_colors = 3
+
+# whether to train the CNN first and reuse its weights in the multi-frame model
+reuse_cnn_flag = False
+
+# base learning rate for each model
+lrs = {
+    'cnn': 1e-3,
+    '1dconv': 1e-4,
+    'lstm': 1e-4,
+    'mix': 1e-4,
+}
+
+weight_decay = 1e-4
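+# linear scaling rule: the base rates in lrs are tuned for batch_size=32, so the
+# effective learning rate is scaled in proportion to the actual batch size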
+learning_rate = lrs[model_type] / 32 * batch_size
+optimizer = tf.train.AdamOptimizer
+
+num_epochs = 60
+
+def train(images, labels, fold, model_type, batch_size, num_epochs, subj_id=0, reuse_cnn=False, 
+    dropout_rate=dropout_rate, learning_rate_default=1e-3, Optimizer=tf.train.AdamOptimizer, log_path=log_path):
+    """
+    A sample training function which loops over the training set and evaluates the network
+    on the validation set after each epoch. Evaluates the network on the training set
+    whenever the
+    :param images: input images
+    :param labels: target labels
+    :param fold: tuple of (train, test) index numbers
+    :param model_type: model type ('cnn', '1dconv', 'lstm', 'mix')
+    :param batch_size: batch size for training
+    :param num_epochs: number of epochs of dataset to go over for training
+    :param subj_id: the id of fold for storing log and the best model
+    :param reuse_cnn: whether to train cnn first, and load its weight for multi-frame model
+    :return: none
+    """
+
+    with tf.name_scope('Inputs'):
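+        # 5-D inputs: [batch, n_frames (time windows), height, width, channels];
+        # the batch and frame dimensions are left dynamic (None)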
+        input_var = tf.placeholder(tf.float32, [None, None, 32, 32, n_colors], name='X_inputs')
+        target_var = tf.placeholder(tf.int64, [None], name='y_inputs')
+        tf_is_training = tf.placeholder(tf.bool, None, name='is_training')
+
+    num_classes = len(np.unique(labels))
+    (X_train, y_train), (X_val, y_val), (X_test, y_test) = reformatInput(images, labels, fold)
+
+
+    print('Train set label and proportion:\t', np.unique(y_train, return_counts=True))
+    print('Val   set label and proportion:\t', np.unique(y_val, return_counts=True))
+    print('Test  set label and proportion:\t', np.unique(y_test, return_counts=True))
+
+    print('The shape of X_train:\t', X_train.shape)
+    print('The shape of X_val:\t', X_val.shape)
+    print('The shape of X_test:\t', X_test.shape)
+    
+
+    print("Building model and compiling functions...")
+    if model_type == '1dconv':
+        network = build_convpool_conv1d(input_var, num_classes, train=tf_is_training, 
+                            dropout_rate=dropout_rate, name='CNN_Conv1d'+'_sbj'+str(subj_id))
+    elif model_type == 'lstm':
+        network = build_convpool_lstm(input_var, num_classes, 100, train=tf_is_training, 
+                            dropout_rate=dropout_rate, name='CNN_LSTM'+'_sbj'+str(subj_id))
+    elif model_type == 'mix':
+        network = build_convpool_mix(input_var, num_classes, 100, train=tf_is_training, 
+                            dropout_rate=dropout_rate, name='CNN_Mix'+'_sbj'+str(subj_id))
+    elif model_type == 'cnn':
+        with tf.name_scope(name='CNN_layer'+'_fold'+str(subj_id)):
+            network = build_cnn(input_var)  # output shape [None, 4, 4, 128]
+            convpool_flat = tf.reshape(network, [-1, 4*4*128])
+            h_fc1_drop1 = tf.layers.dropout(convpool_flat, rate=dropout_rate, training=tf_is_training, name='dropout_1')
+            h_fc1 = tf.layers.dense(h_fc1_drop1, 256, activation=tf.nn.relu, name='fc_relu_256')
+            h_fc1_drop2 = tf.layers.dropout(h_fc1, rate=dropout_rate, training=tf_is_training, name='dropout_2')
+            network = tf.layers.dense(h_fc1_drop2, num_classes, name='fc_softmax')
+            # the loss function contains the softmax activation
+    else:
+        raise ValueError("Model not supported; choose one of ['cnn', '1dconv', 'lstm', 'mix']")
+
+    Train_vars = tf.trainable_variables()
+
+    prediction = network
+
+    with tf.name_scope('Loss'):
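+        # weight decay is applied only to conv/dense kernels ('kernel' in the variable
+        # name), leaving biases and batch-norm parameters unregularized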
+        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in Train_vars if 'kernel' in v.name])
+        ce_loss = tf.losses.sparse_softmax_cross_entropy(labels=target_var, logits=prediction)
+        _loss = ce_loss + weight_decay*l2_loss
+
+    # decay_steps learning rate decay
+    decay_steps = 3*(len(y_train)//batch_size)   # decay every 3 epochs; len(y_train)//batch_size is the number of training steps per epoch
+    with tf.name_scope('Optimizer'):
+        # learning_rate = learning_rate_default * Decay_rate^(global_steps/decay_steps)
+        global_steps = tf.Variable(0, name="global_step", trainable=False)
+        learning_rate = tf.train.exponential_decay(     # learning rate decay
+            learning_rate_default,  # Base learning rate.
+            global_steps,
+            decay_steps,
+            0.95,  # Decay rate.
+            staircase=True)
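+        # e.g. decaying by 0.95 every 3 epochs gives, after 30 epochs,
+        # a rate of learning_rate_default * 0.95^10 ≈ 0.6 * learning_rate_default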
+        optimizer = Optimizer(learning_rate)    # GradientDescentOptimizer  AdamOptimizer
+        train_op = optimizer.minimize(_loss, global_step=global_steps, var_list=Train_vars)
+
+    with tf.name_scope('Accuracy'):
+        prediction = tf.argmax(prediction, axis=1)
+        correct_prediction = tf.equal(prediction, target_var)
+        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+    # Output directory for models and summaries
+    # choose different path for different model and subject
+    out_dir = os.path.abspath(os.path.join(os.path.curdir, log_path, (model_type+'_'+str(subj_id)) ))
+    print("Writing to {}\n".format(out_dir))
+
+    # Summaries for loss, accuracy and learning_rate
+    loss_summary = tf.summary.scalar('loss', _loss)
+    acc_summary = tf.summary.scalar('train_acc', accuracy)
+    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
+
+    # Train Summaries
+    train_summary_op = tf.summary.merge([loss_summary, acc_summary, lr_summary])
+    train_summary_dir = os.path.join(out_dir, "summaries", "train")
+    train_summary_writer = tf.summary.FileWriter(train_summary_dir, tf.get_default_graph())
+
+    # Dev summaries
+    dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
+    dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
+    dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, tf.get_default_graph())
+
+    # Test summaries
+    test_summary_op = tf.summary.merge([loss_summary, acc_summary])
+    test_summary_dir = os.path.join(out_dir, "summaries", "test")
+    test_summary_writer = tf.summary.FileWriter(test_summary_dir, tf.get_default_graph())
+
+
+    # Checkpoint directory. TensorFlow assumes this directory already exists, so we need to create it
+    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
+    checkpoint_prefix = os.path.join(checkpoint_dir, model_type)
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+
+
+    if model_type != 'cnn' and reuse_cnn:
+        # saver for reusing the CNN weights
+        reuse_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='VGG_NET_CNN')
+        original_saver = tf.train.Saver(reuse_vars)         # Pass the variables as a list
+
+    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
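+    # max_to_keep=1 retains only the most recent save; since saver.save() is called
+    # only when validation accuracy improves, this keeps exactly the best checkpoint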
+
+    print("Starting training...")
+    total_start_time = time.time()
+    best_validation_accu = 0
+
+    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
+    with tf.Session() as sess:
+        sess.run(init_op)
+        if model_type != 'cnn' and reuse_cnn:
+            cnn_model_path = os.path.abspath(
+                                os.path.join(
+                                    os.path.curdir, log_path, ('cnn_'+str(subj_id)), 'checkpoints' ))
+            cnn_model_path = tf.train.latest_checkpoint(cnn_model_path)
+            print('-'*20)
+            print('Load cnn model weight for multi-frame model from {}'.format(cnn_model_path))
+            original_saver.restore(sess, cnn_model_path)
+
+        stop_count = 0  # counter for early stopping
+        for epoch in range(num_epochs):
+            print('-'*50)
+            # Train set
+            train_err = train_acc = train_batches = 0
+            start_time = time.time()
+            for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
+                inputs, targets = batch
+                summary, _, pred, loss, acc = sess.run([train_summary_op, train_op, prediction, _loss, accuracy], 
+                    {input_var: inputs, target_var: targets, tf_is_training: True})
+                train_acc += acc
+                train_err += loss
+                train_batches += 1
+                train_summary_writer.add_summary(summary, sess.run(global_steps))
+
+            av_train_err = train_err / train_batches
+            av_train_acc = train_acc / train_batches
+
+            # Val set
+            summary, pred, av_val_err, av_val_acc = sess.run([dev_summary_op, prediction, _loss, accuracy],
+                    {input_var: X_val, target_var: y_val, tf_is_training: False})
+            dev_summary_writer.add_summary(summary, sess.run(global_steps))
+
+            
+            print("Epoch {} of {} took {:.3f}s".format(
+                epoch + 1, num_epochs, time.time() - start_time))
+            
+            fmt_str = "Train \tEpoch [{:d}/{:d}]  train_Loss: {:.4f}\ttrain_Acc: {:.2f}"
+            print_str = fmt_str.format(epoch + 1, num_epochs, av_train_err, av_train_acc*100)
+            print(print_str)
+
+            fmt_str = "Val \tEpoch [{:d}/{:d}]  val_Loss: {:.4f}\tval_Acc: {:.2f}"
+            print_str = fmt_str.format(epoch + 1, num_epochs, av_val_err, av_val_acc*100)
+            print(print_str)
+            
+            # Test set
+            summary, pred, av_test_err, av_test_acc = sess.run([test_summary_op, prediction, _loss, accuracy],
+                {input_var: X_test, target_var: y_test, tf_is_training: False})
+            test_summary_writer.add_summary(summary, sess.run(global_steps))
+            
+            fmt_str = "Test \tEpoch [{:d}/{:d}]  test_Loss: {:.4f}\ttest_Acc: {:.2f}"
+            print_str = fmt_str.format(epoch + 1, num_epochs, av_test_err, av_test_acc*100)
+            print(print_str)
+
+            if av_val_acc > best_validation_accu:   # early-stopping bookkeeping
+                stop_count = 0
+                early_stopping_epoch = epoch
+                best_validation_accu = av_val_acc
+                test_acc_val = av_test_acc
+                saver.save(sess, checkpoint_prefix, global_step=sess.run(global_steps))
+            else:
+                stop_count += 1
+                if stop_count >= 10: # stop training if val_acc does not improve for 10 consecutive epochs
+                    break
+
+        # final pass over the training set in evaluation mode (dropout disabled)
+        train_batches = train_acc = 0
+        for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
+            inputs, targets = batch
+            # feed the current minibatch, not the whole training set
+            acc = sess.run(accuracy, {input_var: inputs, target_var: targets, tf_is_training: False})
+            train_acc += acc
+            train_batches += 1
+
+        last_train_acc = train_acc / train_batches
+        
+        
+        last_val_acc = av_val_acc
+        last_test_acc = av_test_acc
+        print('-'*50)
+        print('Time in total:', time.time()-total_start_time)
+        print("Best validation accuracy:\t\t{:.2f} %".format(best_validation_accu * 100))
+        print("Test accuracy when got the best validation accuracy:\t\t{:.2f} %".format(test_acc_val * 100))
+        print('-'*50)
+        print("Last train accuracy:\t\t{:.2f} %".format(last_train_acc * 100))
+        print("Last validation accuracy:\t\t{:.2f} %".format(last_val_acc * 100))
+        print("Last test accuracy:\t\t\t\t{:.2f} %".format(last_test_acc * 100))
+        print('Epoch of best validation accuracy: {}'.format(early_stopping_epoch+1))
+
+    train_summary_writer.close()
+    dev_summary_writer.close()
+    test_summary_writer.close()
+    return [last_train_acc, best_validation_accu, test_acc_val, last_val_acc, last_test_acc]
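+
+# A minimal usage sketch for train() (assuming images_timewin, labels and fold_pairs
+# have been prepared as in train_all_model() below):
+#     accs = train(images_timewin, labels, fold_pairs[0], '1dconv',
+#                  batch_size=batch_size, num_epochs=num_epochs, subj_id=0)
+#     tf.reset_default_graph()    # the graph must be reset between folds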
+
+
+
+def train_all_model(num_epochs=3000):
+    nums_subject = 13
+    # Leave-Subject-Out cross validation
+    subj_nums = np.squeeze(scipy.io.loadmat('../SampleData/trials_subNums.mat')['subjectNum'])
+    fold_pairs = []
+    for i in np.unique(subj_nums):
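+        # ts: boolean mask over trials of the held-out subject i (test set);
+        # tr: indices of all other subjects' trials (train set)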
+        ts = subj_nums == i
+        tr = np.squeeze(np.nonzero(np.bitwise_not(ts)))
+        ts = np.squeeze(np.nonzero(ts))
+        np.random.shuffle(tr)
+        np.random.shuffle(ts)
+        fold_pairs.append((tr, ts))
+
+
+    images_average, images_timewin, labels = load_or_generate_images(
+                                                file_path='../SampleData/', average_image=3)
+
+
+    print('*'*200)
+    acc_buf = []
+    for subj_id in range(nums_subject):
+        print('-'*100)
+        
+        if model_type == 'cnn':
+            print('Subject', subj_id, '\t\tTraining the cnn model...')
+            acc_temp = train(images_average, labels, fold_pairs[subj_id], 'cnn', 
+                                batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
+                                learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
+            acc_buf.append(acc_temp)
+            tf.reset_default_graph()
+            print('Done!')
+
+        else:
+            # whether to train cnn first, and load its weight for multi-frame model
+            if reuse_cnn_flag is True:
+                print('Subject', subj_id, '\t\tTraining the cnn model...')
+                acc_temp = train(images_average, labels, fold_pairs[subj_id], 'cnn', 
+                                    batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
+                                    learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
+                # acc_buf.append(acc_temp)
+                tf.reset_default_graph()
+                print('Done!')
+        
+            print('Subject', subj_id, '\t\tTraining the ' + model_type + ' model...')
+            if reuse_cnn_flag:
+                print('Loading the CNN model weights for the backbone...')
+            acc_temp = train(images_timewin, labels, fold_pairs[subj_id], model_type, 
+                            batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id, reuse_cnn=reuse_cnn_flag, 
+                            learning_rate_default=learning_rate, Optimizer=optimizer, log_path=log_path)
+                                
+            acc_buf.append(acc_temp)
+            tf.reset_default_graph()
+            print('Done!')
+        
+        # return
+
+    print('All folds for {} are done!'.format(model_type))
+    acc_buf = (np.array(acc_buf)).T
+    acc_mean = np.mean(acc_buf, axis=1).reshape(-1, 1)
+    acc_buf = np.concatenate([acc_buf, acc_mean], axis=1)
+    # the last column is the mean of current row
+    print('Last_train_acc:\t', acc_buf[0], '\tmean:', acc_buf[0][-1])
+    print('Best_val_acc:\t', acc_buf[1], '\tmean:', acc_buf[1][-1])
+    print('Earlystopping_test_acc:\t', acc_buf[2], '\tmean:', acc_buf[2][-1])
+    print('Last_val_acc:\t', acc_buf[3], '\tmean:', acc_buf[3][-1])
+    print('Last_test_acc:\t', acc_buf[4], '\tmean:', acc_buf[4][-1])
+    np.savetxt('./Accuracy_{}.csv'.format(model_type), acc_buf, fmt='%.4f', delimiter=',')
+
+
+if __name__ == '__main__':
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+    np.random.seed(2018)
+    tf.set_random_seed(2018)
+
+    train_all_model(num_epochs=num_epochs)