RefineNet & SESNet/SESNet/test.py
import time
import os
import numpy as np
import tensorflow as tf
import sys
import cv2

slim = tf.contrib.slim
sys.path.append(os.getcwd())
from nets import model as model
from matplotlib import pyplot as plt
from utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors

tf.app.flags.DEFINE_string('model_type', 'refinenet', 'which network variant to build (passed to nets.model.model)')
tf.app.flags.DEFINE_string('test_data_path', 'data/test.tfrecords', 'path to the test TFRecord file')
tf.app.flags.DEFINE_string('gpu_list', '1', 'value used for CUDA_VISIBLE_DEVICES')
tf.app.flags.DEFINE_integer('num_classes', 2, 'number of segmentation classes')
tf.app.flags.DEFINE_string('checkpoint_path', 'checkpoints/', 'directory containing the trained checkpoint')
tf.app.flags.DEFINE_string('result_path', 'result/', 'directory where prediction images are written')
tf.app.flags.DEFINE_integer('test_size', 384, 'nominal test image size (unused in this script)')

FLAGS = tf.app.flags.FLAGS

def main(argv=None):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    if not os.path.exists(FLAGS.result_path):
        os.makedirs(FLAGS.result_path)

    # read one image/annotation pair at a time from the test TFRecord
    filename_queue = tf.train.string_input_producer([FLAGS.test_data_path], num_epochs=1)
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)

    image_batch_tensor = tf.expand_dims(image, axis=0)
    annotation_batch_tensor = tf.expand_dims(annotation, axis=0)

    # round height and width to the nearest multiple of 32 so the network's
    # repeated downsampling stages divide the spatial dimensions evenly
    input_image_shape = tf.shape(image_batch_tensor)
    image_height_width = input_image_shape[1:3]
    image_height_width_float = tf.to_float(image_height_width)
    image_height_width_multiple = tf.to_int32(tf.round(image_height_width_float / 32) * 32)

    image_batch_tensor = tf.image.resize_images(image_batch_tensor, image_height_width_multiple)

    # build the network and take the per-pixel argmax as the prediction
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    logits = model.model(FLAGS.model_type, image_batch_tensor, is_training=False)
    pred = tf.argmax(logits, axis=3)
    pred = tf.expand_dims(pred, 3)

    # resize prediction and annotation back to the original resolution and
    # rescale annotation values from [0, 255] to [0, 1]
    pred = tf.image.resize_bilinear(images=pred, size=image_height_width)
    annotation_batch_tensor = tf.image.resize_bilinear(images=annotation_batch_tensor, size=image_height_width)
    annotation_batch_tensor = tf.div(annotation_batch_tensor, 255)

    pred = tf.reshape(pred, [-1, ])
    gt = tf.reshape(annotation_batch_tensor, [-1, ])

    # streaming metrics accumulated over the whole test set
    acc, acc_update_op = tf.contrib.metrics.streaming_accuracy(pred, gt)
    miou, miou_update_op = tf.contrib.metrics.streaming_mean_iou(pred, gt, num_classes=FLAGS.num_classes)

    with tf.get_default_graph().as_default():
        global_vars_init_op = tf.global_variables_initializer()
        # local variables back the input queue's epoch counter and the streaming metrics
        local_vars_init_op = tf.local_variables_initializer()
        init = tf.group(local_vars_init_op, global_vars_init_op)

        # restore the exponential-moving-average shadow weights rather than the raw variables
        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False,
                                gpu_options=gpu_options)
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init)
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            # start the queue-runner threads that feed the TFRecord reader
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for i in range(150):
                start = time.time()
                image_np, annotation_np, pred_np, tmp_acc, tmp_miou = sess.run(
                    [image, annotation, pred, acc_update_op, miou_update_op])
                _diff_time = time.time() - start
                print('{}: cost {:.0f}ms'.format(i, _diff_time * 1000))
                # upsampled_predictions = pred_np.squeeze()
                # plt.subplot(131)
                # plt.imshow(image_np)
                # plt.subplot(132)
                # plt.imshow(annotation_np.squeeze(), cmap='gray')
                # plt.subplot(133)
                # plt.imshow(np.reshape(pred_np, (annotation_np.shape[0], annotation_np.shape[1])).squeeze(), cmap='gray')
                # plt.savefig(os.path.join(FLAGS.result_path, str(i) + '.png'))
                # scale the binary prediction to [0, 255] and save it as a PNG
                prediction = np.reshape(pred_np, (annotation_np.shape[0], annotation_np.shape[1])).squeeze() * 255
                cv2.imwrite(os.path.join(FLAGS.result_path, str(i) + '.png'), prediction.astype(np.uint8))

            # report the streaming metrics accumulated over all processed samples
            final_acc, final_miou = sess.run([acc, miou])
            print('pixel accuracy: {:.4f}, mean IoU: {:.4f}'.format(final_acc, final_miou))
            print('Test Finished !')

            # stop and join the queue-runner threads while the session is still open
            coord.request_stop()
            coord.join(threads)
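
# Example invocation (values shown are the flag defaults; adjust paths to your setup).
# Predictions are written to <result_path>/<i>.png and the accumulated pixel accuracy
# and mean IoU are printed when the loop finishes:
#
#   python test.py --model_type refinenet \
#                  --test_data_path data/test.tfrecords \
#                  --checkpoint_path checkpoints/ \
#                  --result_path result/ \
#                  --gpu_list 1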

if __name__ == '__main__':
    tf.app.run()