Diff of /SegNet/test.py [000000] .. [bc8010]

Switch to unified view

a b/SegNet/test.py
1
import os
2
import scipy
3
import tensorflow as tf
4
import tensorflow.contrib.slim as slim
5
6
import SegNetCMR
7
8
9
WORKING_DIR = os.getcwd()
10
TRAINING_DIR = os.path.join(WORKING_DIR, 'Data', 'Training')
11
TEST_DIR = os.path.join(WORKING_DIR, 'Data', 'Test')
12
13
ROOT_LOG_DIR = os.path.join(WORKING_DIR, 'Output')
14
RUN_NAME = "Run_new"
15
LOG_DIR = os.path.join(ROOT_LOG_DIR, RUN_NAME)
16
TRAIN_WRITER_DIR = os.path.join(LOG_DIR, 'Train')
17
TEST_WRITER_DIR = os.path.join(LOG_DIR, 'Test')
18
OUTPUT_IMAGE_DIR = os.path.join(LOG_DIR, 'Image_Output')
19
20
CHECKPOINT_FN = 'model.ckpt'
21
CHECKPOINT_FL = os.path.join(LOG_DIR, CHECKPOINT_FN)
22
23
24
BATCH_NORM_DECAY = 0.95 #Start off at 0.9, then increase.
25
MAX_STEPS = 1000
26
BATCH_SIZE = 5
27
SAVE_INTERVAL = 50
28
29
def main():
30
    test_data = SegNetCMR.GetData(TEST_DIR)
31
    g = tf.Graph()
32
33
    with g.as_default():
34
35
        images, labels, is_training = SegNetCMR.placeholder_inputs(batch_size=BATCH_SIZE)
36
37
        arg_scope = SegNetCMR.inference_scope(is_training=False, batch_norm_decay=BATCH_NORM_DECAY)
38
39
        with slim.arg_scope(arg_scope):
40
            logits = SegNetCMR.inference(images, class_inc_bg=2)
41
42
        accuracy = SegNetCMR.evaluation(logits=logits, labels=labels)
43
44
        init = tf.global_variables_initializer()
45
46
        saver = tf.train.Saver([x for x in tf.global_variables() if 'Adam' not in x.name])
47
48
        sm = tf.train.SessionManager()
49
50
        with sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=LOG_DIR) as sess:
51
52
            sess.run(tf.variables_initializer([x for x in tf.global_variables() if 'Adam' in x.name]))
53
54
            accuracy_all = 0
55
            now = 0
56
            epochs = 30
57
            for step in range(epochs):
58
                images_batch, labels_batch = test_data.next_batch_test(now, BATCH_SIZE)
59
60
                test_feed_dict = {images: images_batch,
61
                                  labels: labels_batch,
62
                                  is_training: False}
63
64
                mask, accuracy_batch = sess.run([logits, accuracy], feed_dict=test_feed_dict)
65
66
                for idx in range(BATCH_SIZE):
67
                    name = str(step*BATCH_SIZE+idx)
68
                    resize_image = scipy.misc.imresize(mask[idx, :, :, 1].astype(int), [768, 768], interp='cubic')
69
                    scipy.misc.imsave(os.path.join(OUTPUT_IMAGE_DIR, '{}.png'.format(name)), resize_image)
70
71
                now += BATCH_SIZE
72
                accuracy_all += accuracy_batch
73
74
            accuracy_mean = accuracy_all / epochs
75
            print('accuracy:{}'.format(accuracy_mean))
76
77
if __name__ == '__main__':
78
    main()