b/EEGLearn/train.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yang Wang
## School of Automation, Huazhong University of Science & Technology (HUST)
## wangyang_sky@hust.edu.cn
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# coding:utf-8

import os
import tensorflow as tf
import numpy as np
import scipy.io
import time
import datetime

from utils import reformatInput, load_or_generate_images, iterate_minibatches

from model import build_cnn, build_convpool_conv1d, build_convpool_lstm, build_convpool_mix

timestamp = datetime.datetime.now().strftime('%Y-%m-%d.%H.%M')
log_path = os.path.join("runs", timestamp)


model_type = '1dconv'  # ['1dconv', 'maxpool', 'lstm', 'mix', 'cnn']
log_path = log_path + '_' + model_type

batch_size = 32
dropout_rate = 0.5

input_shape = [32, 32, 3]  # 32*32 = 1024 pixels, 3 color channels
nb_class = 4
n_colors = 3

# whether to train the CNN first and load its weights for the multi-frame models
reuse_cnn_flag = False

# learning rates for the different models
lrs = {
    'cnn': 1e-3,
    '1dconv': 1e-4,
    'lstm': 1e-4,
    'mix': 1e-4,
}

weight_decay = 1e-4
learning_rate = lrs[model_type] / 32 * batch_size
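# the base rates in `lrs` assume a batch size of 32; the line above rescales them linearly with the actual batch_size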
|
|
optimizer = tf.train.AdamOptimizer

num_epochs = 60

def train(images, labels, fold, model_type, batch_size, num_epochs, subj_id=0, reuse_cnn=False,
          dropout_rate=dropout_rate, learning_rate_default=1e-3, Optimizer=tf.train.AdamOptimizer, log_path=log_path):
    """
    A sample training function which loops over the training set and evaluates the network
    on the validation and test sets after each epoch. The training-set accuracy is computed
    once more after training finishes.
    :param images: input images
    :param labels: target labels
    :param fold: tuple of (train, test) index numbers
    :param model_type: model type ('cnn', '1dconv', 'lstm', 'mix')
    :param batch_size: batch size for training
    :param num_epochs: number of epochs of dataset to go over for training
    :param subj_id: the id of the fold, used for naming the log and the best-model checkpoint
    :param reuse_cnn: whether to train the cnn first, and load its weights for the multi-frame model
    :return: list of [last_train_acc, best_validation_accu, test_acc_val, last_val_acc, last_test_acc]
    """

    with tf.name_scope('Inputs'):
        input_var = tf.placeholder(tf.float32, [None, None, 32, 32, n_colors], name='X_inputs')
        target_var = tf.placeholder(tf.int64, [None], name='y_inputs')
        tf_is_training = tf.placeholder(tf.bool, None, name='is_training')
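        # input_var is 5-D: [batch, n_frames (time windows), 32, 32, n_colors]; both the batch
        # and frame dimensions are left as None so they can vary between runs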
|
|
    num_classes = len(np.unique(labels))
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = reformatInput(images, labels, fold)

    print('Train set label and proportion:\t', np.unique(y_train, return_counts=True))
    print('Val set label and proportion:\t', np.unique(y_val, return_counts=True))
    print('Test set label and proportion:\t', np.unique(y_test, return_counts=True))

    print('The shape of X_train:\t', X_train.shape)
    print('The shape of X_val:\t', X_val.shape)
    print('The shape of X_test:\t', X_test.shape)

    print("Building model and compiling functions...")
    if model_type == '1dconv':
        network = build_convpool_conv1d(input_var, num_classes, train=tf_is_training,
                                        dropout_rate=dropout_rate, name='CNN_Conv1d'+'_sbj'+str(subj_id))
    elif model_type == 'lstm':
        network = build_convpool_lstm(input_var, num_classes, 100, train=tf_is_training,
                                      dropout_rate=dropout_rate, name='CNN_LSTM'+'_sbj'+str(subj_id))
    elif model_type == 'mix':
        network = build_convpool_mix(input_var, num_classes, 100, train=tf_is_training,
                                     dropout_rate=dropout_rate, name='CNN_Mix'+'_sbj'+str(subj_id))
    elif model_type == 'cnn':
        with tf.name_scope(name='CNN_layer'+'_fold'+str(subj_id)):
            network = build_cnn(input_var)  # output shape [None, 4, 4, 128]
            convpool_flat = tf.reshape(network, [-1, 4*4*128])
            h_fc1_drop1 = tf.layers.dropout(convpool_flat, rate=dropout_rate, training=tf_is_training, name='dropout_1')
            h_fc1 = tf.layers.dense(h_fc1_drop1, 256, activation=tf.nn.relu, name='fc_relu_256')
            h_fc1_drop2 = tf.layers.dropout(h_fc1, rate=dropout_rate, training=tf_is_training, name='dropout_2')
            network = tf.layers.dense(h_fc1_drop2, num_classes, name='fc_softmax')
            # the loss function contains the softmax activation
    else:
        raise ValueError("Model not supported ['1dconv', 'maxpool', 'lstm', 'mix', 'cnn']")

    Train_vars = tf.trainable_variables()

    prediction = network

    with tf.name_scope('Loss'):
        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in Train_vars if 'kernel' in v.name])
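        # only variables whose name contains 'kernel' (layer weight matrices, not biases) are L2-regularized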
|
|
        ce_loss = tf.losses.sparse_softmax_cross_entropy(labels=target_var, logits=prediction)
        _loss = ce_loss + weight_decay*l2_loss

    # decay the learning rate every decay_steps optimizer steps
    decay_steps = 3*(len(y_train)//batch_size)  # len(y_train)//batch_size is the number of training steps in one epoch
    with tf.name_scope('Optimizer'):
        # learning_rate = learning_rate_default * Decay_rate^(global_steps/decay_steps)
        global_steps = tf.Variable(0, name="global_step", trainable=False)
        learning_rate = tf.train.exponential_decay(  # learning rate decay
            learning_rate_default,  # Base learning rate.
            global_steps,
            decay_steps,
            0.95,  # Decay rate.
            staircase=True)
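        # with staircase decay the rate drops to 95% of its value every decay_steps steps, i.e. every 3 epochs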
|
|
        optimizer = Optimizer(learning_rate)  # GradientDescentOptimizer  AdamOptimizer
        train_op = optimizer.minimize(_loss, global_step=global_steps, var_list=Train_vars)

    with tf.name_scope('Accuracy'):
        prediction = tf.argmax(prediction, axis=1)
        correct_prediction = tf.equal(prediction, target_var)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Output directory for models and summaries
    # choose different path for different model and subject
    out_dir = os.path.abspath(os.path.join(os.path.curdir, log_path, (model_type+'_'+str(subj_id))))
    print("Writing to {}\n".format(out_dir))

    # Summaries for loss, accuracy and learning_rate
    loss_summary = tf.summary.scalar('loss', _loss)
    acc_summary = tf.summary.scalar('train_acc', accuracy)
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)

    # Train Summaries
    train_summary_op = tf.summary.merge([loss_summary, acc_summary, lr_summary])
    train_summary_dir = os.path.join(out_dir, "summaries", "train")
    train_summary_writer = tf.summary.FileWriter(train_summary_dir, tf.get_default_graph())

    # Dev summaries
    dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
    dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
    dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, tf.get_default_graph())

    # Test summaries
    test_summary_op = tf.summary.merge([loss_summary, acc_summary])
    test_summary_dir = os.path.join(out_dir, "summaries", "test")
    test_summary_writer = tf.summary.FileWriter(test_summary_dir, tf.get_default_graph())


    # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, model_type)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if model_type != 'cnn' and reuse_cnn:
        # saver for reusing the CNN weights
        reuse_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='VGG_NET_CNN')
        original_saver = tf.train.Saver(reuse_vars)  # Pass the variables as a list

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

    print("Starting training...")
    total_start_time = time.time()
    best_validation_accu = 0

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        if model_type != 'cnn' and reuse_cnn:
            cnn_model_path = os.path.abspath(
                os.path.join(
                    os.path.curdir, log_path, ('cnn_'+str(subj_id)), 'checkpoints'))
            cnn_model_path = tf.train.latest_checkpoint(cnn_model_path)
            print('-'*20)
            print('Load cnn model weights for the multi-frame model from {}'.format(cnn_model_path))
            original_saver.restore(sess, cnn_model_path)

        stop_count = 0  # counter for early stopping
        for epoch in range(num_epochs):
            print('-'*50)
            # Train set
            train_err = train_acc = train_batches = 0
            start_time = time.time()
            for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
                inputs, targets = batch
                summary, _, pred, loss, acc = sess.run([train_summary_op, train_op, prediction, _loss, accuracy],
                                                       {input_var: inputs, target_var: targets, tf_is_training: True})
                train_acc += acc
                train_err += loss
                train_batches += 1
                train_summary_writer.add_summary(summary, sess.run(global_steps))

            av_train_err = train_err / train_batches
            av_train_acc = train_acc / train_batches

            # Val set
            summary, pred, av_val_err, av_val_acc = sess.run([dev_summary_op, prediction, _loss, accuracy],
                                                             {input_var: X_val, target_var: y_val, tf_is_training: False})
            dev_summary_writer.add_summary(summary, sess.run(global_steps))

            print("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs, time.time() - start_time))

            fmt_str = "Train \tEpoch [{:d}/{:d}] train_Loss: {:.4f}\ttrain_Acc: {:.2f}"
            print_str = fmt_str.format(epoch + 1, num_epochs, av_train_err, av_train_acc*100)
            print(print_str)

            fmt_str = "Val \tEpoch [{:d}/{:d}] val_Loss: {:.4f}\tval_Acc: {:.2f}"
            print_str = fmt_str.format(epoch + 1, num_epochs, av_val_err, av_val_acc*100)
            print(print_str)
            # Test set
            summary, pred, av_test_err, av_test_acc = sess.run([test_summary_op, prediction, _loss, accuracy],
                                                               {input_var: X_test, target_var: y_test, tf_is_training: False})
            test_summary_writer.add_summary(summary, sess.run(global_steps))

            fmt_str = "Test \tEpoch [{:d}/{:d}] test_Loss: {:.4f}\ttest_Acc: {:.2f}"
            print_str = fmt_str.format(epoch + 1, num_epochs, av_test_err, av_test_acc*100)
            print(print_str)

            if av_val_acc > best_validation_accu:  # early stopping
                stop_count = 0
                early_stopping_epoch = epoch
                best_validation_accu = av_val_acc
                test_acc_val = av_test_acc
                saver.save(sess, checkpoint_prefix, global_step=sess.run(global_steps))
            else:
                stop_count += 1
                if stop_count >= 10:  # stop training if val_acc does not improve for 10 consecutive epochs
                    break

        # final training-set accuracy, computed in minibatches with the trained weights
        train_batches = train_acc = 0
        for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=False):
            inputs, targets = batch
            acc = sess.run(accuracy, {input_var: inputs, target_var: targets, tf_is_training: False})
            train_acc += acc
            train_batches += 1

        last_train_acc = train_acc / train_batches

        last_val_acc = av_val_acc
        last_test_acc = av_test_acc
        print('-'*50)
        print('Time in total:', time.time()-total_start_time)
        print("Best validation accuracy:\t\t{:.2f} %".format(best_validation_accu * 100))
        print("Test accuracy at the best validation epoch:\t\t{:.2f} %".format(test_acc_val * 100))
        print('-'*50)
        print("Last train accuracy:\t\t{:.2f} %".format(last_train_acc * 100))
        print("Last validation accuracy:\t\t{:.2f} %".format(last_val_acc * 100))
        print("Last test accuracy:\t\t\t\t{:.2f} %".format(last_test_acc * 100))
        print('Early stopping at epoch: {}'.format(early_stopping_epoch+1))

    train_summary_writer.close()
    dev_summary_writer.close()
    test_summary_writer.close()
    return [last_train_acc, best_validation_accu, test_acc_val, last_val_acc, last_test_acc]


def train_all_model(num_epochs=3000):
    nums_subject = 13
    # Leave-Subject-Out cross validation
    subj_nums = np.squeeze(scipy.io.loadmat('../SampleData/trials_subNums.mat')['subjectNum'])
    fold_pairs = []
    for i in np.unique(subj_nums):
        ts = subj_nums == i
        tr = np.squeeze(np.nonzero(np.bitwise_not(ts)))
        ts = np.squeeze(np.nonzero(ts))
        np.random.shuffle(tr)
        np.random.shuffle(ts)
        fold_pairs.append((tr, ts))
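        # each fold holds out all trials of subject i for testing (ts) and trains on the remaining subjects (tr)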
|
|

    images_average, images_timewin, labels = load_or_generate_images(
        file_path='../SampleData/', average_image=3)

    print('*'*200)
    acc_buf = []
    for subj_id in range(nums_subject):
        print('-'*100)

        if model_type == 'cnn':
            print('Fold', subj_id, '\t\tTraining the cnn model...')
            acc_temp = train(images_average, labels, fold_pairs[subj_id], 'cnn',
                             batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
                             learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
            acc_buf.append(acc_temp)
            tf.reset_default_graph()
            print('Done!')

        else:
            # optionally train the cnn first and load its weights as the backbone of the multi-frame model
            if reuse_cnn_flag is True:
                print('Fold', subj_id, '\t\tTraining the cnn model...')
                acc_temp = train(images_average, labels, fold_pairs[subj_id], 'cnn',
                                 batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id,
                                 learning_rate_default=lrs['cnn'], Optimizer=optimizer, log_path=log_path)
                # acc_buf.append(acc_temp)
                tf.reset_default_graph()
                print('Done!')

            print('Fold', subj_id, '\t\tTraining the ' + model_type + ' model...')
            if reuse_cnn_flag:
                print('Loading the CNN model weights for the backbone...')
            acc_temp = train(images_timewin, labels, fold_pairs[subj_id], model_type,
                             batch_size=batch_size, num_epochs=num_epochs, subj_id=subj_id, reuse_cnn=reuse_cnn_flag,
                             learning_rate_default=learning_rate, Optimizer=optimizer, log_path=log_path)

            acc_buf.append(acc_temp)
            tf.reset_default_graph()
            print('Done!')

        # return

    print('All folds for {} are done!'.format(model_type))
    acc_buf = (np.array(acc_buf)).T
    acc_mean = np.mean(acc_buf, axis=1).reshape(-1, 1)
    acc_buf = np.concatenate([acc_buf, acc_mean], axis=1)
    # the last column is the mean of the current row
    print('Last_train_acc:\t', acc_buf[0], '\tmean :', np.mean(acc_buf[0][-1]))
    print('Best_val_acc:\t', acc_buf[1], '\tmean :', np.mean(acc_buf[1][-1]))
    print('Earlystopping_test_acc:\t', acc_buf[2], '\tmean :', np.mean(acc_buf[2][-1]))
    print('Last_val_acc:\t', acc_buf[3], '\tmean :', np.mean(acc_buf[3][-1]))
    print('Last_test_acc:\t', acc_buf[4], '\tmean :', np.mean(acc_buf[4][-1]))
    np.savetxt('./Accuracy_{}.csv'.format(model_type), acc_buf, fmt='%.4f', delimiter=',')
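    # saved CSV layout: one row per metric (last_train, best_val, test@best_val, last_val, last_test),
    # one column per fold plus a final column holding the row mean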
|
|
347 |
|
|
|
348 |
|
|
|
349 |
if __name__ == '__main__': |
|
|
350 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
351 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
352 |
np.random.seed(2018) |
|
|
353 |
tf.set_random_seed(2018) |
|
|
354 |
|
|
|
355 |
train_all_model(num_epochs=num_epochs) |