|
# model_train.py
|
import os
import pandas as pd
import numpy as np

import tensorflow as tf
import config

from model_utils import x, y, keep_prob
from model_utils import input_img, reshape_op

from model_utils import evaluate_log_loss, accuracy, evaluate_validation_set
from model_utils import model_store_path, store_error_plots, evaluate_test_set
from model_utils import high_error_increase, display_confusion_matrix_info
from model_utils import get_specificity, get_sensitivity, validate_data_loaded
from model import loss_function_with_logits, sparse_loss_with_logits
from model_factory import ModelFactory


# Parameters used during training
batch_size = config.BATCH_SIZE
learning_rate = 0.001
training_iters = 101
save_step = 10
display_steps = 20
validation_log_loss_incr_threshold = 0.1
last_errors = 2
tolerance = 20
dropout = 0.5  # Dropout, probability to keep units
beta = 0.01
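
# last_errors, tolerance and validation_log_loss_incr_threshold drive the
# early-stopping check in the training loop below (see high_error_increase).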
|
|
# Construct model
factory = ModelFactory()
model = factory.get_network_model()

if not config.RESTORE:
    # Add tensors to the collection stored in the model graph
    # definition
    tf.add_to_collection('vars', x)
    tf.add_to_collection('vars', y)
    tf.add_to_collection('vars', keep_prob)

    for weight_var in model.weights():
        tf.add_to_collection('vars', weight_var)

    for bias_var in model.biases():
        tf.add_to_collection('vars', bias_var)
|
|
pred = model.conv_net(x, dropout)

with tf.name_scope("cross_entropy"):
    # Define loss and optimizer
    cost = sparse_loss_with_logits(pred, y)

    # Add L2 regularization on the weights of the fully connected layer
    # if a non-zero regularization term is returned
    regularizer = model.l2_regularizer()
    if regularizer != 0:
        print("Adding L2 regularization...")
        cost = tf.reduce_mean(cost + beta * regularizer)
|
|
trainable_vars = tf.trainable_variables()

with tf.name_scope("train"):
    gradients = tf.gradients(cost, trainable_vars)
    gradients = list(zip(gradients, trainable_vars))
    optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in trainable_vars:
    tf.summary.histogram(var.name, var)
|
|
# Predictions for the training, validation, and test data.
softmax_prediction = tf.nn.softmax(pred, name='softmax_prediction')

if not config.RESTORE:
    tf.add_to_collection('vars', cost)
    tf.add_to_collection('vars', softmax_prediction)


merged = tf.summary.merge_all()
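
# `merged` bundles the histogram summaries defined above; the per-epoch scalar
# metrics are exported separately through export_evaluation_summary below.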
|
|
# ======= Training ========
data_loader = factory.get_data_loader()
training_set = data_loader.get_training_set()
validation_set = data_loader.get_validation_set()
exact_tests = data_loader.get_exact_tests_set()
model_out_dir = data_loader.results_out_dir()

print('Validation examples count: ', validation_set.num_samples)
print('Test examples count: ', exact_tests.num_samples)
print('Model will be stored in: ', model_out_dir)
|
|
# Initializing the variables
init = tf.global_variables_initializer()

saver = tf.train.Saver()

validation_errors = []
train_errors_per_epoch = []
best_validation_err = 1.0
best_validation_sensitivity = 0.0
|
|
# Add summary for log loss per epoch, accuracy and sensitivity
with tf.name_scope("log_loss"):
    log_loss = tf.placeholder(tf.float32, name="log_loss_per_epoch")

loss_summary = tf.summary.scalar("log_loss", log_loss)

with tf.name_scope("sensitivity"):
    sensitivity = tf.placeholder(tf.float32, name="sensitivity_per_epoch")

sensitivity_summary = tf.summary.scalar("sensitivity", sensitivity)

with tf.name_scope("accuracy"):
    tf_accuracy = tf.placeholder(tf.float32, name="accuracy_per_epoch")

accuracy_summary = tf.summary.scalar("accuracy", tf_accuracy)
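
# The per-epoch metrics are computed in Python and fed into the placeholders
# above, so they can be written to TensorBoard as scalar summaries.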
|
|
def export_evaluation_summary(log_loss_value,
                              accuracy_value,
                              sensitivity_value,
                              step,
                              sess,
                              writer):
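    """Write the per-epoch log loss, accuracy and sensitivity values to TensorBoard.

    The metric values are fed into the summary placeholders and the resulting
    summaries are added to `writer` for the given global step.
    """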
|
|
    error_summary, acc_summary, sens_summary = sess.run(
        [loss_summary, accuracy_summary, sensitivity_summary],
        feed_dict={log_loss: log_loss_value, tf_accuracy: accuracy_value,
                   sensitivity: sensitivity_value})
    writer.add_summary(error_summary, global_step=step)
    writer.add_summary(acc_summary, global_step=step)
    writer.add_summary(sens_summary, global_step=step)
    writer.flush()
|
|
# Launch the graph
with tf.Session() as sess:
    if not os.path.exists(config.SUMMARIES_DIR):
        os.makedirs(config.SUMMARIES_DIR)

    train_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR, 'train'))
    validation_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR,
                                                           'validation'))

    sess.run(init)
|
|
    if config.RESTORE and \
            os.path.exists(os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT + '.index')):
        saver.restore(sess, os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT))
        print("Restoring model from last saved state: ", config.RESTORE_MODEL_CKPT)

    # Add the model graph to TensorBoard
    if not config.RESTORE:
        train_writer.add_graph(sess.graph)
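
    # Main training loop: each step draws `training_set.num_samples` mini-batches,
    # accumulates predictions and labels for epoch-level metrics, then exports
    # summaries and stores checkpoints periodically.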
|
|
    for step in range(config.START_STEP, training_iters):
        train_pred = []
        train_labels = []

        for i in range(training_set.num_samples):
            batch_data, batch_labels = training_set.next_batch(batch_size)

            if not validate_data_loaded(batch_data, batch_labels):
                break

            reshaped = sess.run(reshape_op, feed_dict={input_img: np.stack(batch_data)})
            feed_dict = {x: reshaped, y: batch_labels, keep_prob: dropout}

            if step % display_steps == 0:
                _, loss, predictions, summary = sess.run(
                    [train_op, cost, softmax_prediction, merged],
                    feed_dict=feed_dict)

                try:
                    train_writer.add_summary(summary, step + i)
                except Exception as e:
                    print("Exception raised during summary export. ", e)
            else:
                _, loss, predictions = sess.run([train_op, cost, softmax_prediction],
                                                feed_dict=feed_dict)

            train_pred.extend(predictions)
            train_labels.extend(batch_labels)

        train_writer.flush()
|
|
        if step % save_step == 0:
            print("Storing model snapshot...")
            saver.save(sess, model_store_path(model_out_dir, 'lungs' + str(step)))

        print("Train epoch {} finished. {} samples processed.".format(
            training_set.finished_epochs, len(train_pred)))
|
|
        if not len(train_pred):
            break

        train_acc_epoch = accuracy(np.stack(train_pred), np.stack(train_labels))

        train_log_loss = evaluate_log_loss(train_pred, train_labels)

        print('Train log loss error {}.'.format(train_log_loss))
        print('Train set accuracy {}.'.format(train_acc_epoch))
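
        # Sensitivity (true positive rate) and specificity (true negative rate)
        # for this epoch are derived from the training confusion matrix.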
|
|
        print('Train set confusion matrix.')
        confusion_matrix = display_confusion_matrix_info(train_labels, train_pred)
        train_sensitivity = get_sensitivity(confusion_matrix)
        train_specificity = get_specificity(confusion_matrix)
        print('Train set sensitivity {} and specificity {}'.format(
            train_sensitivity, train_specificity))
|
|
        export_evaluation_summary(train_log_loss,
                                  train_acc_epoch,
                                  train_sensitivity,
                                  step, sess, train_writer)

        print('Evaluate validation set')
        validation_acc, validation_log_loss, val_sensitivity, val_specificity = \
            evaluate_validation_set(sess,
                                    validation_set,
                                    softmax_prediction,
                                    x,
                                    batch_size)
        if not validation_log_loss:
            break

        export_evaluation_summary(validation_log_loss,
                                  validation_acc,
                                  val_sensitivity,
                                  step, sess, validation_writer)
|
|
        print('Validation accuracy: %.1f%%' % validation_acc)
        print('Log loss over all validation samples: {}.'.format(
            validation_log_loss))
        print('Validation set sensitivity {} and specificity {}'.format(
            val_sensitivity, val_specificity))
|
|
        if validation_log_loss < best_validation_err and val_sensitivity > best_validation_sensitivity:
            best_validation_err = validation_log_loss
            best_validation_sensitivity = val_sensitivity
            print("Storing model snapshot with best validation error {} and sensitivity {}".format(
                best_validation_err, best_validation_sensitivity))
            if step % save_step != 0:
                saver.save(sess, model_store_path(model_out_dir, 'best_err' + str(step)))

        if validation_log_loss < 0.1:
            print("Low enough log loss validation error, terminate!")
            break
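
        # Early stopping: if the validation log loss has grown beyond the allowed
        # threshold over the last `last_errors` epochs, keep training only while
        # the train loss still decreases and some tolerance is left; otherwise stop.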
|
|
        if high_error_increase(validation_errors[-last_errors:],
                               validation_log_loss,
                               last_errors,
                               validation_log_loss_incr_threshold):
            if tolerance and train_log_loss <= train_errors_per_epoch[-1]:
                print("Train error still decreases, continue...")
                tolerance -= 1
                validation_errors.append(validation_log_loss)
                train_errors_per_epoch.append(train_log_loss)
                continue

            print("Validation log loss has increased more than the allowed threshold",
                  " for the past iterations, terminate!")
            print("Last iterations: ", validation_errors[-last_errors:])
            print("Current validation error: ", validation_log_loss)
            break

        validation_errors.append(validation_log_loss)
        train_errors_per_epoch.append(train_log_loss)
|
|
    train_writer.close()
    validation_writer.close()

    saver.save(sess, model_store_path(model_out_dir, 'last'))
    print("Model saved...")
    store_error_plots(validation_errors, train_errors_per_epoch)


    # ============= REAL TEST DATA EVALUATION =====================
    evaluate_test_set(sess,
                      exact_tests,
                      softmax_prediction,
                      x)