|
a |
|
b/SegNet/train.py |
|
|
1 |
import os |
|
|
2 |
|
|
|
3 |
import tensorflow as tf |
|
|
4 |
import tensorflow.contrib.slim as slim |
|
|
5 |
|
|
|
6 |
import SegNetCMR |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
WORKING_DIR = os.getcwd() |
|
|
10 |
TRAINING_DIR = os.path.join(WORKING_DIR, 'Data', 'Training') |
|
|
11 |
TEST_DIR = os.path.join(WORKING_DIR, 'Data', 'Test') |
|
|
12 |
|
|
|
13 |
ROOT_LOG_DIR = os.path.join(WORKING_DIR, 'Output') |
|
|
14 |
RUN_NAME = "Run_double" |
|
|
15 |
LOG_DIR = os.path.join(ROOT_LOG_DIR, RUN_NAME) |
|
|
16 |
TRAIN_WRITER_DIR = os.path.join(LOG_DIR, 'Train') |
|
|
17 |
TEST_WRITER_DIR = os.path.join(LOG_DIR, 'Test') |
|
|
18 |
|
|
|
19 |
CHECKPOINT_FN = 'model.ckpt' |
|
|
20 |
CHECKPOINT_FL = os.path.join(LOG_DIR, CHECKPOINT_FN) |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
BATCH_NORM_DECAY = 0.95 #Start off at 0.9, then increase. |
|
|
24 |
MAX_STEPS = 10000 |
|
|
25 |
BATCH_SIZE = 3 |
|
|
26 |
SAVE_INTERVAL = 50 |
|
|
27 |
|
|
|
28 |
TEST = True |
|
|
29 |
def main(): |
|
|
30 |
training_data = SegNetCMR.GetData(TRAINING_DIR) |
|
|
31 |
test_data = SegNetCMR.GetData(TEST_DIR) |
|
|
32 |
|
|
|
33 |
g = tf.Graph() |
|
|
34 |
|
|
|
35 |
with g.as_default(): |
|
|
36 |
|
|
|
37 |
images, labels, is_training = SegNetCMR.placeholder_inputs(batch_size=BATCH_SIZE) |
|
|
38 |
|
|
|
39 |
arg_scope = SegNetCMR.inference_scope(is_training=True, batch_norm_decay=BATCH_NORM_DECAY) |
|
|
40 |
|
|
|
41 |
with slim.arg_scope(arg_scope): |
|
|
42 |
logits = SegNetCMR.inference(images, class_inc_bg=2) |
|
|
43 |
|
|
|
44 |
SegNetCMR.add_output_images(images=images, logits=logits, labels=labels) |
|
|
45 |
|
|
|
46 |
loss = SegNetCMR.loss_calc(logits=logits, labels=labels) |
|
|
47 |
|
|
|
48 |
train_op, global_step = SegNetCMR.training(loss=loss, learning_rate=1e-04) |
|
|
49 |
|
|
|
50 |
accuracy = SegNetCMR.evaluation(logits=logits, labels=labels) |
|
|
51 |
|
|
|
52 |
summary = tf.summary.merge_all() |
|
|
53 |
|
|
|
54 |
init = tf.global_variables_initializer() |
|
|
55 |
|
|
|
56 |
saver = tf.train.Saver([x for x in tf.global_variables() if 'Adam' not in x.name]) |
|
|
57 |
|
|
|
58 |
sm = tf.train.SessionManager() |
|
|
59 |
|
|
|
60 |
with sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=LOG_DIR) as sess: |
|
|
61 |
|
|
|
62 |
sess.run(tf.variables_initializer([x for x in tf.global_variables() if 'Adam' in x.name])) |
|
|
63 |
|
|
|
64 |
train_writer = tf.summary.FileWriter(TRAIN_WRITER_DIR, sess.graph) |
|
|
65 |
test_writer = tf.summary.FileWriter(TEST_WRITER_DIR) |
|
|
66 |
|
|
|
67 |
global_step_value, = sess.run([global_step]) |
|
|
68 |
|
|
|
69 |
print("Last trained iteration was: ", global_step_value) |
|
|
70 |
|
|
|
71 |
for step in range(global_step_value+1, global_step_value+MAX_STEPS+1): |
|
|
72 |
|
|
|
73 |
print("Iteration: ", step) |
|
|
74 |
|
|
|
75 |
images_batch, labels_batch = training_data.next_batch(BATCH_SIZE) |
|
|
76 |
|
|
|
77 |
train_feed_dict = {images: images_batch, |
|
|
78 |
labels: labels_batch, |
|
|
79 |
is_training: True} |
|
|
80 |
|
|
|
81 |
_, train_loss_value, train_accuracy_value, train_summary_str = sess.run([train_op, loss, accuracy, summary], feed_dict=train_feed_dict) |
|
|
82 |
|
|
|
83 |
if step % SAVE_INTERVAL == 0 and TEST: |
|
|
84 |
|
|
|
85 |
print("Train Loss: ", train_loss_value) |
|
|
86 |
print("Train accuracy: ", train_accuracy_value) |
|
|
87 |
train_writer.add_summary(train_summary_str, step) |
|
|
88 |
train_writer.flush() |
|
|
89 |
|
|
|
90 |
images_batch, labels_batch = test_data.next_batch(BATCH_SIZE) |
|
|
91 |
|
|
|
92 |
test_feed_dict = {images: images_batch, |
|
|
93 |
labels: labels_batch, |
|
|
94 |
is_training: False} |
|
|
95 |
|
|
|
96 |
test_loss_value, test_accuracy_value, test_summary_str = sess.run([loss, accuracy, summary], feed_dict=test_feed_dict) |
|
|
97 |
|
|
|
98 |
print("Test Loss: ", test_loss_value) |
|
|
99 |
print("Test accuracy: ", test_accuracy_value) |
|
|
100 |
test_writer.add_summary(test_summary_str, step) |
|
|
101 |
test_writer.flush() |
|
|
102 |
|
|
|
103 |
saver.save(sess, CHECKPOINT_FL, global_step=step) |
|
|
104 |
print("Session Saved") |
|
|
105 |
print("================") |
|
|
106 |
|
|
|
107 |
|
|
|
108 |
if __name__ == '__main__': |
|
|
109 |
main() |