# model_train.py
import os
import numpy as np

import tensorflow as tf
import config

from model_utils import x, y, keep_prob
from model_utils import input_img, reshape_op
from model_utils import evaluate_log_loss, accuracy, evaluate_validation_set
from model_utils import model_store_path, store_error_plots, evaluate_test_set
from model_utils import high_error_increase, display_confusion_matrix_info
from model_utils import get_specificity, get_sensitivity, validate_data_loaded
from model import sparse_loss_with_logits
from model_factory import ModelFactory


# Parameters used during training
batch_size = config.BATCH_SIZE
learning_rate = 0.001
training_iters = 101
save_step = 10
display_steps = 20
validation_log_loss_incr_threshold = 0.1
last_errors = 2
tolerance = 20
dropout = 0.5  # Dropout, probability to keep units
beta = 0.01  # weight of the L2 regularization term
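
# `last_errors`, `tolerance` and `validation_log_loss_incr_threshold` control
# early stopping; see the high_error_increase check in the training loop below.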

# Construct model
factory = ModelFactory()
model = factory.get_network_model()

if not config.RESTORE:
    # Add tensors to a collection stored in the model graph definition
    tf.add_to_collection('vars', x)
    tf.add_to_collection('vars', y)
    tf.add_to_collection('vars', keep_prob)

    for weight_var in model.weights():
        tf.add_to_collection('vars', weight_var)

    for bias_var in model.biases():
        tf.add_to_collection('vars', bias_var)
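
# Tensors in the 'vars' collection are saved with the exported meta graph, so
# a restored model can look them up again via tf.get_collection('vars').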

# Pass the keep_prob placeholder (fed with `dropout` through feed_dict below)
# so dropout can be switched off at evaluation time.
pred = model.conv_net(x, keep_prob)

with tf.name_scope("cross_entropy"):
    # Define loss and optimizer
    cost = sparse_loss_with_logits(pred, y)

    # Add L2 regularization on the weights of the fully connected layer
    # if a non-zero term is returned.
    regularizer = model.l2_regularizer()
    if regularizer != 0:
        print("Adding L2 regularization...")
        cost = tf.reduce_mean(cost + beta * regularizer)

trainable_vars = tf.trainable_variables()

with tf.name_scope("train"):
    gradients = tf.gradients(cost, trainable_vars)
    gradients = list(zip(gradients, trainable_vars))
    optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)
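
# Gradients are computed and applied explicitly (rather than via
# optimizer.minimize) so the (gradient, variable) pairs are also available
# for the histogram summaries below.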

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in trainable_vars:
    tf.summary.histogram(var.name, var)

# Predictions for the training, validation, and test data.
softmax_prediction = tf.nn.softmax(pred, name='softmax_prediction')

if not config.RESTORE:
    tf.add_to_collection('vars', cost)
    tf.add_to_collection('vars', softmax_prediction)


merged = tf.summary.merge_all()
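
# Note: merge_all only gathers the summaries defined up to this point (the
# histograms above); the per-epoch scalar summaries created below are run
# separately through export_evaluation_summary.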

# ======= Training ========
data_loader = factory.get_data_loader()
training_set = data_loader.get_training_set()
validation_set = data_loader.get_validation_set()
exact_tests = data_loader.get_exact_tests_set()
model_out_dir = data_loader.results_out_dir()

print('Validation examples count: ', validation_set.num_samples)
print('Test examples count: ', exact_tests.num_samples)
print('Model will be stored in: ', model_out_dir)


# Initializing the variables
init = tf.global_variables_initializer()

saver = tf.train.Saver()

validation_errors = []
train_errors_per_epoch = []
best_validation_err = 1.0
best_validation_sensitivity = 0.0

# Add summary for log loss, accuracy and sensitivity per epoch
with tf.name_scope("log_loss"):
    log_loss = tf.placeholder(tf.float32, name="log_loss_per_epoch")

loss_summary = tf.summary.scalar("log_loss", log_loss)

with tf.name_scope("sensitivity"):
    sensitivity = tf.placeholder(tf.float32, name="sensitivity_per_epoch")

sensitivity_summary = tf.summary.scalar("sensitivity", sensitivity)

with tf.name_scope("accuracy"):
    tf_accuracy = tf.placeholder(tf.float32, name="accuracy_per_epoch")

accuracy_summary = tf.summary.scalar("accuracy", tf_accuracy)
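
# The per-epoch metric values are computed outside the graph, so plain
# placeholders are used to feed them into these scalar summaries.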


def export_evaluation_summary(log_loss_value,
                              accuracy_value,
                              sensitivity_value,
                              step,
                              sess,
                              writer):
    """Evaluate the per-epoch scalar summaries and write them to `writer`."""
    error_summary, acc_summary, sens_summary = sess.run(
        [loss_summary, accuracy_summary, sensitivity_summary],
        feed_dict={log_loss: log_loss_value, tf_accuracy: accuracy_value,
                   sensitivity: sensitivity_value})
    writer.add_summary(error_summary, global_step=step)
    writer.add_summary(acc_summary, global_step=step)
    writer.add_summary(sens_summary, global_step=step)
    writer.flush()


# Launch the graph
with tf.Session() as sess:
    if not os.path.exists(config.SUMMARIES_DIR):
        os.makedirs(config.SUMMARIES_DIR)

    train_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR, 'train'))
    validation_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR,
                                                           'validation'))

    sess.run(init)

    if config.RESTORE and \
       os.path.exists(os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT + '.index')):

        saver.restore(sess, os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT))
        print("Restoring model from last saved state: ", config.RESTORE_MODEL_CKPT)


    # Add the model graph to TensorBoard
    if not config.RESTORE:
        train_writer.add_graph(sess.graph)
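
    # Main training loop: each step draws batches from the training set,
    # exports summaries every `display_steps` steps, snapshots the model every
    # `save_step` steps, then evaluates the validation set for early stopping.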
    for step in range(config.START_STEP, training_iters):
        train_pred = []
        train_labels = []

        for i in range(training_set.num_samples):
            batch_data, batch_labels = training_set.next_batch(batch_size)

            if not validate_data_loaded(batch_data, batch_labels):
                break

            reshaped = sess.run(reshape_op, feed_dict={input_img: np.stack(batch_data)})
            feed_dict = {x: reshaped, y: batch_labels, keep_prob: dropout}

            if step % display_steps == 0:
                _, loss, predictions, summary = sess.run(
                    [train_op, cost, softmax_prediction, merged],
                    feed_dict=feed_dict)

                try:
                    train_writer.add_summary(summary, step + i)
                except Exception as e:
                    print("Exception raised during summary export. ", e)
            else:
                _, loss, predictions = sess.run([train_op, cost, softmax_prediction],
                                                feed_dict=feed_dict)

            train_pred.extend(predictions)
            train_labels.extend(batch_labels)

        train_writer.flush()

        if step % save_step == 0:
            print("Storing model snapshot...")
            saver.save(sess, model_store_path(model_out_dir, 'lungs' + str(step)))

        print("Train epoch {} finished. {} samples processed.".format(
            training_set.finished_epochs, len(train_pred)))

        if not len(train_pred):
            break

        train_acc_epoch = accuracy(np.stack(train_pred), np.stack(train_labels))
        train_log_loss = evaluate_log_loss(train_pred, train_labels)

        print('Train log loss error {}.'.format(train_log_loss))
        print('Train set accuracy {}.'.format(train_acc_epoch))
        print('Train set confusion matrix.')
        confusion_matrix = display_confusion_matrix_info(train_labels, train_pred)
        train_sensitivity = get_sensitivity(confusion_matrix)
        train_specificity = get_specificity(confusion_matrix)
        print('Train set sensitivity {} and specificity {}'.format(
            train_sensitivity, train_specificity))

        export_evaluation_summary(train_log_loss,
                                  train_acc_epoch,
                                  train_sensitivity,
                                  step, sess, train_writer)

        print('Evaluate validation set')
        validation_acc, validation_log_loss, val_sensitivity, val_specificity = \
            evaluate_validation_set(sess,
                                    validation_set,
                                    softmax_prediction,
                                    x,
                                    batch_size)
        if not validation_log_loss:
            break

        export_evaluation_summary(validation_log_loss,
                                  validation_acc,
                                  val_sensitivity,
                                  step, sess, validation_writer)

        print('Validation accuracy: %.1f%%' % validation_acc)
        print('Log loss over all validation samples: {}.'.format(
            validation_log_loss))
        print('Validation set sensitivity {} and specificity {}'.format(
            val_sensitivity, val_specificity))

        if validation_log_loss < best_validation_err and val_sensitivity > best_validation_sensitivity:
            best_validation_err = validation_log_loss
            best_validation_sensitivity = val_sensitivity
            print("Storing model snapshot with best validation error {} and sensitivity {}".format(
                best_validation_err, best_validation_sensitivity))
            if step % save_step != 0:
                saver.save(sess, model_store_path(model_out_dir, 'best_err' + str(step)))

        if validation_log_loss < 0.1:
            print("Low enough log loss validation error, terminate!")
            break
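
        # Early stopping check: the validation log loss has risen more than the
        # allowed threshold over the last epochs; continue only while the train
        # loss still decreases, for at most `tolerance` additional epochs.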
        if high_error_increase(validation_errors[-last_errors:],
                               validation_log_loss,
                               last_errors,
                               validation_log_loss_incr_threshold):
            # Guard against an empty history on the very first epochs
            if tolerance and train_errors_per_epoch and \
               train_log_loss <= train_errors_per_epoch[-1]:
                print("Train error still decreases, continue...")
                tolerance -= 1
                validation_errors.append(validation_log_loss)
                train_errors_per_epoch.append(train_log_loss)
                continue

            print("Validation log loss has increased more than the allowed "
                  "threshold for the past iterations, terminate!")
            print("Last iterations: ", validation_errors[-last_errors:])
            print("Current validation error: ", validation_log_loss)
            break

        validation_errors.append(validation_log_loss)
        train_errors_per_epoch.append(train_log_loss)

    train_writer.close()
    validation_writer.close()

    saver.save(sess, model_store_path(model_out_dir, 'last'))
    print("Model saved...")
    store_error_plots(validation_errors, train_errors_per_epoch)


    # ============= REAL TEST DATA EVALUATION =====================
    evaluate_test_set(sess,
                      exact_tests,
                      softmax_prediction,
                      x)