|
a |
|
b/SegNet/test.py |
|
|
1 |
import os |
|
|
2 |
import scipy |
|
|
3 |
import tensorflow as tf |
|
|
4 |
import tensorflow.contrib.slim as slim |
|
|
5 |
|
|
|
6 |
import SegNetCMR |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
# --- Filesystem layout: every path is resolved against the launch directory ---
WORKING_DIR = os.getcwd()
TRAINING_DIR, TEST_DIR = (os.path.join(WORKING_DIR, 'Data', part) for part in ('Training', 'Test'))

# Per-run output tree (writer sub-directories, rendered images, checkpoint file).
ROOT_LOG_DIR = os.path.join(WORKING_DIR, 'Output')
RUN_NAME = "Run_new"
LOG_DIR = os.path.join(ROOT_LOG_DIR, RUN_NAME)
TRAIN_WRITER_DIR, TEST_WRITER_DIR, OUTPUT_IMAGE_DIR = (
    os.path.join(LOG_DIR, part) for part in ('Train', 'Test', 'Image_Output')
)

CHECKPOINT_FN = 'model.ckpt'
CHECKPOINT_FL = os.path.join(LOG_DIR, CHECKPOINT_FN)

# --- Hyper-parameters ---
BATCH_NORM_DECAY = 0.95  # suggested schedule: start off at 0.9, then increase
MAX_STEPS = 1000
BATCH_SIZE = 5
SAVE_INTERVAL = 50
def main(num_batches=30):
    """Evaluate the SegNetCMR model on the test set and save predicted masks.

    Restores the newest checkpoint in ``LOG_DIR`` (all global variables whose
    name does not contain 'Adam'), then runs ``num_batches`` batches of
    ``BATCH_SIZE`` test images through the network. For every image, channel 1
    of the logits is resized to 768x768 and written as ``<index>.png`` into
    ``OUTPUT_IMAGE_DIR``. Finally the mean of the per-batch accuracies is
    printed.

    Args:
        num_batches: Number of test batches to evaluate. Defaults to 30, the
            value that was previously hard-coded (as ``epochs``).

    Raises:
        ValueError: If ``num_batches`` is less than 1.
    """
    # ``import scipy`` alone does not load the ``scipy.misc`` submodule on the
    # SciPy releases that still provide ``imresize``/``imsave`` (both removed
    # in SciPy 1.3), so import the submodule explicitly before using it.
    import scipy.misc

    if num_batches < 1:
        raise ValueError('num_batches must be >= 1, got {}'.format(num_batches))

    test_data = SegNetCMR.GetData(TEST_DIR)

    # The output directory may not exist yet on a fresh run; imsave will not create it.
    if not os.path.isdir(OUTPUT_IMAGE_DIR):
        os.makedirs(OUTPUT_IMAGE_DIR)

    g = tf.Graph()
    with g.as_default():
        images, labels, is_training = SegNetCMR.placeholder_inputs(batch_size=BATCH_SIZE)

        # The graph is built with is_training=False since this script only evaluates.
        arg_scope = SegNetCMR.inference_scope(is_training=False, batch_norm_decay=BATCH_NORM_DECAY)
        with slim.arg_scope(arg_scope):
            logits = SegNetCMR.inference(images, class_inc_bg=2)

        accuracy = SegNetCMR.evaluation(logits=logits, labels=labels)

        init = tf.global_variables_initializer()

        # Restore only non-optimizer variables; Adam slots are not needed for evaluation.
        saver = tf.train.Saver([x for x in tf.global_variables() if 'Adam' not in x.name])

        # Without a checkpoint, prepare_session silently falls back to init_op,
        # which would evaluate randomly initialised weights.
        if tf.train.latest_checkpoint(LOG_DIR) is None:
            print('warning: no checkpoint found in {}; '
                  'evaluating randomly initialised weights'.format(LOG_DIR))

        sm = tf.train.SessionManager()
        with sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=LOG_DIR) as sess:
            # Any 'Adam' variables were excluded from the restore, so give them
            # initial values (this op does nothing when the graph contains none).
            sess.run(tf.variables_initializer([x for x in tf.global_variables() if 'Adam' in x.name]))

            accuracy_all = 0
            for batch_idx in range(num_batches):
                # Offset of this batch's first image within the test set.
                start = batch_idx * BATCH_SIZE
                images_batch, labels_batch = test_data.next_batch_test(start, BATCH_SIZE)

                test_feed_dict = {images: images_batch,
                                  labels: labels_batch,
                                  is_training: False}

                mask, accuracy_batch = sess.run([logits, accuracy], feed_dict=test_feed_dict)

                for idx in range(BATCH_SIZE):
                    name = str(start + idx)
                    # NOTE(review): this truncates raw channel-1 logits to int
                    # before resizing; an argmax/softmax over channels may have
                    # been intended -- confirm.
                    resize_image = scipy.misc.imresize(mask[idx, :, :, 1].astype(int), [768, 768], interp='cubic')
                    scipy.misc.imsave(os.path.join(OUTPUT_IMAGE_DIR, '{}.png'.format(name)), resize_image)

                accuracy_all += accuracy_batch

            accuracy_mean = accuracy_all / num_batches
            print('accuracy:{}'.format(accuracy_mean))
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()