|
a |
|
b/RefineNet & SESNet/SESNet/test.py |
|
|
1 |
import time |
|
|
2 |
import os |
|
|
3 |
import numpy as np |
|
|
4 |
import tensorflow as tf |
|
|
5 |
import sys |
|
|
6 |
import cv2 |
|
|
7 |
|
|
|
8 |
# Alias for TF-Slim; it is not referenced anywhere else in this script.
slim = tf.contrib.slim
# Put the current working directory on sys.path so the project-local
# `nets` / `utils` imports that follow can be resolved. Presumably the script
# is launched from the project root — confirm. This must stay above those
# imports.
sys.path.append(os.getcwd())
|
|
10 |
from nets import model as model |
|
|
11 |
from matplotlib import pyplot as plt |
|
|
12 |
from utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors |
|
|
13 |
|
|
|
14 |
# Command-line flags for the evaluation script. tf.app.run() parses them
# before main() is called. Help strings were previously empty, so --help
# documented nothing.
tf.app.flags.DEFINE_string(
    'model_type', 'refinenet',
    "Model type string forwarded to nets.model.model() (e.g. 'refinenet').")
tf.app.flags.DEFINE_string(
    'test_data_path', 'data/test.tfrecords',
    'TFRecord file containing the (image, annotation) test pairs.')
tf.app.flags.DEFINE_string(
    'gpu_list', '1',
    'Value assigned to CUDA_VISIBLE_DEVICES before building the graph.')
tf.app.flags.DEFINE_integer(
    'num_classes', 2,
    'Number of segmentation classes used when computing mean IoU.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', 'checkpoints/',
    'Directory holding the trained checkpoint to restore.')
tf.app.flags.DEFINE_string(
    'result_path', 'result/',
    'Directory the predicted masks (<index>.png) are written to.')
# NOTE(review): test_size is not read anywhere in this script.
tf.app.flags.DEFINE_integer(
    'test_size', 384,
    'Currently unused by this evaluation script.')

FLAGS = tf.app.flags.FLAGS
|
|
23 |
|
|
|
24 |
def _build_eval_graph():
    """Build the input pipeline, model, streaming metrics and saver.

    Returns:
        dict with the following entries:
          ``pred_map``: predicted class index per pixel at the input image's
              original resolution, shape (height, width).
          ``acc`` / ``acc_update_op``: streaming pixel accuracy value / update.
          ``miou`` / ``miou_update_op``: streaming mean IoU value / update.
          ``saver``: Saver that restores the exponential-moving-average
              (shadow) weights into the model variables.
    """
    # num_epochs=1: once every record has been read, the queue raises
    # OutOfRangeError. main() uses that as its stop condition.
    filename_queue = tf.train.string_input_producer(
        [FLAGS.test_data_path], num_epochs=1)
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
        filename_queue)

    image_batch_tensor = tf.expand_dims(image, axis=0)
    annotation_batch_tensor = tf.expand_dims(annotation, axis=0)

    # Original (height, width) of the test image. Predictions are mapped back
    # to this size before scoring.
    image_height_width = tf.shape(image_batch_tensor)[1:3]
    # Feed the network a size rounded to a multiple of 32. The floor of 32
    # stops very small images from rounding down to an empty tensor.
    image_height_width_multiple = tf.maximum(
        tf.to_int32(tf.round(tf.to_float(image_height_width) / 32) * 32), 32)
    image_batch_tensor = tf.image.resize_images(
        image_batch_tensor, image_height_width_multiple)

    global_step = tf.get_variable(
        'global_step', [], initializer=tf.constant_initializer(0),
        trainable=False)
    logits = model.model(FLAGS.model_type, image_batch_tensor, is_training=False)

    # Class indices are categorical, so resize them with nearest-neighbour
    # interpolation. Bilinear interpolation would invent fractional "classes"
    # along object boundaries.
    pred = tf.expand_dims(tf.argmax(logits, axis=3), 3)
    pred = tf.image.resize_nearest_neighbor(pred, image_height_width)
    pred_map = tf.squeeze(pred, axis=[0, 3])

    # The /255 assumes labels are stored as 0/255 masks (binary case).
    # TODO confirm the encoding for num_classes > 2. Rounding keeps any
    # intermediate values (e.g. from lossy encoding) from yielding non-integer
    # labels, which mean IoU would otherwise truncate.
    gt = tf.image.resize_nearest_neighbor(
        annotation_batch_tensor, image_height_width)
    gt = tf.to_int64(tf.round(tf.to_float(gt) / 255.0))

    pred_flat = tf.reshape(pred, [-1])
    gt_flat = tf.reshape(gt, [-1])
    acc, acc_update_op = tf.contrib.metrics.streaming_accuracy(pred_flat, gt_flat)
    miou, miou_update_op = tf.contrib.metrics.streaming_mean_iou(
        pred_flat, gt_flat, num_classes=FLAGS.num_classes)

    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    return {
        'pred_map': pred_map,
        'acc': acc,
        'acc_update_op': acc_update_op,
        'miou': miou,
        'miou_update_op': miou_update_op,
        'saver': saver,
    }


def _restore_latest_checkpoint(sess, saver):
    """Restore the newest checkpoint recorded in FLAGS.checkpoint_path.

    Raises:
        IOError: if no checkpoint state is found in FLAGS.checkpoint_path.
    """
    ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
    if ckpt_state is None or not ckpt_state.model_checkpoint_path:
        raise IOError('No checkpoint found in {}'.format(FLAGS.checkpoint_path))
    # Only the file name from the state file is kept; it is re-joined with
    # checkpoint_path, so the directory recorded at training time is ignored.
    model_path = os.path.join(
        FLAGS.checkpoint_path,
        os.path.basename(ckpt_state.model_checkpoint_path))
    print('Restore from {}'.format(model_path))
    saver.restore(sess, model_path)


def main(argv=None):
    """Evaluate a trained segmentation model on FLAGS.test_data_path.

    Runs the test set for exactly one epoch. For each record, the predicted
    mask is written to FLAGS.result_path as ``<index>.png`` and the inference
    latency is printed. At the end, the streaming pixel accuracy and mean IoU
    are printed.

    Args:
        argv: Unused; accepted for ``tf.app.run`` compatibility.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    if not os.path.exists(FLAGS.result_path):
        os.makedirs(FLAGS.result_path)

    graph = _build_eval_graph()
    # Local variables hold the input-epoch counter and the metric
    # accumulators, so both initializers are required.
    init = tf.group(tf.local_variables_initializer(),
                    tf.global_variables_initializer())

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            gpu_options=gpu_options)
    config.gpu_options.allow_growth = True

    # Stretch class indices over 0..255 so the saved masks are visible.
    # For the default num_classes=2 this is the original x255.
    mask_scale = 255 // max(FLAGS.num_classes - 1, 1)

    with tf.Session(config=config) as sess:
        sess.run(init)
        _restore_latest_checkpoint(sess, graph['saver'])

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_images = 0
        try:
            while not coord.should_stop():
                start = time.time()
                # One run dequeues one record and feeds it to both the
                # prediction and the metric updates.
                pred_np, _, _ = sess.run([graph['pred_map'],
                                          graph['acc_update_op'],
                                          graph['miou_update_op']])
                elapsed_ms = (time.time() - start) * 1000
                print('{}: cost {:.0f}ms'.format(num_images, elapsed_ms))

                mask = (pred_np * mask_scale).astype(np.uint8)
                cv2.imwrite(
                    os.path.join(FLAGS.result_path, str(num_images) + '.png'),
                    mask)
                num_images += 1
        except tf.errors.OutOfRangeError:
            pass  # The single epoch is exhausted: every record was scored.
        finally:
            coord.request_stop()
            coord.join(threads)

        # The metric values read only the accumulated local variables, so
        # they can be evaluated after the input threads have stopped.
        accuracy, mean_iou = sess.run([graph['acc'], graph['miou']])
        print('Test Finished !')
        print('images: {}, pixel accuracy: {:.4f}, mean IoU: {:.4f}'.format(
            num_images, accuracy, mean_iou))
|
|
101 |
|
|
|
102 |
|
|
|
103 |
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls the entry point
    # (named explicitly here rather than looked up on __main__).
    tf.app.run(main=main)