|
a |
|
b/test_image_all_720um.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import numpy as np |
|
|
3 |
import cv2 |
|
|
4 |
from pathlib import Path |
|
|
5 |
import scipy.misc |
|
|
6 |
import Network |
|
|
7 |
import os |
|
|
8 |
|
|
|
9 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
10 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # only uses GPU 1 |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
results_dir = './results/' |
|
|
14 |
results_best_dir = './results/best_model/' |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
save_file_UNET = './results/npo_cubic_e2e_MSE.csv' |
|
|
18 |
save_file_UNET_SSIM = './results/npo_cubic_e2e_SSIM.csv' |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
# read test image |
|
|
22 |
image_read_dir = './image/' |
|
|
23 |
GT_path = str(Path(image_read_dir + 'test_img.png')) |
|
|
24 |
gt_img = cv2.imread(GT_path, 0)/255 |
|
|
25 |
GT = np.tile(gt_img, [21, 1, 1]) |
|
|
26 |
GT = np.expand_dims(GT, -1) |
|
|
27 |
GT = tf.convert_to_tensor(GT, dtype=tf.float32) |
|
|
28 |
|
|
|
29 |
#save image |
|
|
30 |
image_save_dir = './image/results/' |
|
|
31 |
|
|
|
32 |
####### read from the TFRECORD format ################# |
|
|
33 |
## for faster reading from Hard disk |
|
|
34 |
def read_tfrecord(TFRECORD_PATH): |
|
|
35 |
# from tfrecord file to data |
|
|
36 |
N_w = 326 # size of the images |
|
|
37 |
N_h = 326 |
|
|
38 |
queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True) |
|
|
39 |
reader = tf.TFRecordReader() |
|
|
40 |
|
|
|
41 |
_, serialized_example = reader.read(queue) |
|
|
42 |
|
|
|
43 |
features = tf.parse_single_example(serialized_example, |
|
|
44 |
features={ |
|
|
45 |
'sharp': tf.FixedLenFeature([], tf.string), |
|
|
46 |
}) |
|
|
47 |
|
|
|
48 |
RGB_flat = tf.decode_raw(features['sharp'], tf.uint8) |
|
|
49 |
RGB = tf.reshape(RGB_flat, [N_h, N_w, 1]) |
|
|
50 |
|
|
|
51 |
return RGB |
|
|
52 |
|
|
|
53 |
|
|
|
54 |
########## Preprocess the images ############# |
|
|
55 |
## crop to patches |
|
|
56 |
## random flip |
|
|
57 |
## Add uniform noise |
|
|
58 |
############################################ |
|
|
59 |
def data_augment(RGB_batch_float): |
|
|
60 |
# crop to N_raw x N_raw |
|
|
61 |
N_raw = 326 # for boundary effect, 256+70, will need cropping after convolution |
|
|
62 |
|
|
|
63 |
data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float) |
|
|
64 |
|
|
|
65 |
# flip both images and labels |
|
|
66 |
data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1) |
|
|
67 |
|
|
|
68 |
# only adjust the RGB value of the image |
|
|
69 |
r1 = tf.random_uniform([]) * 0.3 + 0.8 |
|
|
70 |
RGB_out = data2 * r1 |
|
|
71 |
|
|
|
72 |
return RGB_out |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
|
|
|
76 |
############ Put data in batches ############# |
|
|
77 |
## put in batch and shuffle |
|
|
78 |
## cast to float32 |
|
|
79 |
## call data_augment for image preprocess |
|
|
80 |
## @param{TFRECORD_PATH}: path to the data |
|
|
81 |
## @param{batchsize}: currently 21 for the 21 PSFs |
|
|
82 |
############################################## |
|
|
83 |
def read2batch(TFRECORD_PATH, batchsize): |
|
|
84 |
# load tfrecord and make them to be usable data |
|
|
85 |
RGB = read_tfrecord(TFRECORD_PATH) |
|
|
86 |
#RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200, num_threads=5) |
|
|
87 |
RGB = tf.expand_dims(RGB, axis=0) |
|
|
88 |
RGB_batch = tf.tile(RGB, [21,1,1,1]) |
|
|
89 |
RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32) |
|
|
90 |
|
|
|
91 |
# padd the target for convolution |
|
|
92 |
RGB_batch_float = tf.image.resize_image_with_crop_or_pad(RGB_batch_float, 298, 298) |
|
|
93 |
|
|
|
94 |
return RGB_batch_float[:, :, :, 0:1] |
|
|
95 |
|
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def add_gaussian_noise(images, std): |
|
|
99 |
noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32) |
|
|
100 |
return tf.nn.relu(images + noise) |
|
|
101 |
|
|
|
102 |
|
|
|
103 |
################ blur the images using PSFs ################## |
|
|
104 |
## same patch different depths put in a stack |
|
|
105 |
################################################################ |
|
|
106 |
def one_wvl_blur(im, PSFs0): |
|
|
107 |
N_B = PSFs0.shape[1].value |
|
|
108 |
N_Phi = PSFs0.shape[0].value |
|
|
109 |
N_im = im.shape[1].value |
|
|
110 |
N_im_out = N_im - N_B + 1 # the final image size after blurring |
|
|
111 |
|
|
|
112 |
sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]), |
|
|
113 |
[0, 2, 3, 1]) # reshape to make N_Phi in the last channel |
|
|
114 |
PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1) |
|
|
115 |
blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID') |
|
|
116 |
blurStack = tf.transpose( |
|
|
117 |
tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]), |
|
|
118 |
perm=[0, 2, 3, 1]) # stack all N_Phi images to the first dimension |
|
|
119 |
|
|
|
120 |
return blurStack |
|
|
121 |
|
|
|
122 |
|
|
|
123 |
def blurImage_diffPatch_diffDepth(RGB, PSFs): |
|
|
124 |
blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0]) |
|
|
125 |
|
|
|
126 |
return blur |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
####################### system ########################## |
|
|
130 |
## @param{PSFs}: the PSFs |
|
|
131 |
## @param{RGB_batch_float}: patches |
|
|
132 |
## @param{phase_BN}: batch normalization, True only during training |
|
|
133 |
######################################################## |
|
|
134 |
def system(PSFs, RGB_batch_float, phase_BN=False): |
|
|
135 |
with tf.variable_scope("system", reuse=tf.AUTO_REUSE): |
|
|
136 |
blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs) # size [batch_size * N_Phi, Nx, Ny, 3] |
|
|
137 |
|
|
|
138 |
# noise |
|
|
139 |
sigma = 0.01 |
|
|
140 |
blur_noisy = add_gaussian_noise(blur, sigma) |
|
|
141 |
|
|
|
142 |
RGB_hat = Network.UNet(blur_noisy, phase_BN) |
|
|
143 |
|
|
|
144 |
return blur_noisy, RGB_hat |
|
|
145 |
|
|
|
146 |
|
|
|
147 |
###################### RMS cost ############################# |
|
|
148 |
## @param{GT}: ground truth |
|
|
149 |
## @param{hat}: reconstruction |
|
|
150 |
############################################################## |
|
|
151 |
def cost_rms(GT, hat): |
|
|
152 |
cost = tf.sqrt(tf.reduce_mean(tf.reduce_mean((tf.square(GT - hat)),1),1)) |
|
|
153 |
return cost |
|
|
154 |
|
|
|
155 |
###################### SSIM cost ############################# |
|
|
156 |
## @param{GT}: ground truth |
|
|
157 |
## @param{hat}: reconstruction |
|
|
158 |
############################################################## |
|
|
159 |
def cost_ssim(GT, hat): |
|
|
160 |
cost = tf.image.ssim(GT, hat, 1.0) # assume img intensity ranges from 0 to 1 |
|
|
161 |
cost = tf.expand_dims(cost, axis = 1) |
|
|
162 |
return cost |
|
|
163 |
|
|
|
164 |
########## compare the reconstruction reblured with U-net input? ############ |
|
|
165 |
## important for EDOF to utilize the PSF information |
|
|
166 |
## @param{RGB_hat}: Unet reconstructed image |
|
|
167 |
## @param{PSFs}: PSF used |
|
|
168 |
## @param{blur}: all-in-focus image conv PSF |
|
|
169 |
## @param{N_B}: size of blur kernel |
|
|
170 |
## @return{reblur}: reconstruction blurred |
|
|
171 |
## @return{cost}: l2 norm between blur_GT and reblur |
|
|
172 |
############################################################################## |
|
|
173 |
def cost_reblur(RGB_hat, PSFs, blur, N_B): |
|
|
174 |
reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs) |
|
|
175 |
blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2), |
|
|
176 |
:] # crop the patch to 256x256 |
|
|
177 |
|
|
|
178 |
cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur))) |
|
|
179 |
|
|
|
180 |
return reblur, cost |
|
|
181 |
|
|
|
182 |
|
|
|
183 |
######################################################## PARAMETER #################################################### |
|
|
184 |
N_B = 71 |
|
|
185 |
|
|
|
186 |
N_Phi = 21 |
|
|
187 |
batch_size = N_Phi |
|
|
188 |
Phi_list = np.linspace(-10, 10, N_Phi, np.float32) # defocus |
|
|
189 |
|
|
|
190 |
PSFs = np.load(results_dir + 'PSFs.npy') |
|
|
191 |
PSFs = tf.convert_to_tensor(PSFs, dtype=tf.float32) |
|
|
192 |
|
|
|
193 |
|
|
|
194 |
####################################################### architecture ################################################### |
|
|
195 |
RGB_batch_float_test = GT |
|
|
196 |
|
|
|
197 |
[blur_test, RGB_hat_test] = system(PSFs, RGB_batch_float_test) |
|
|
198 |
|
|
|
199 |
# cost function |
|
|
200 |
with tf.name_scope("cost"): |
|
|
201 |
RGB_GT_test = RGB_batch_float_test[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), |
|
|
202 |
int((N_B - 1) / 2):-int((N_B - 1) / 2), :] # crop the all-in-focus to be |
|
|
203 |
|
|
|
204 |
|
|
|
205 |
cost_rms_test = cost_rms(RGB_GT_test, RGB_hat_test) |
|
|
206 |
cost_ssim_test = cost_ssim(RGB_GT_test, RGB_hat_test) |
|
|
207 |
|
|
|
208 |
|
|
|
209 |
|
|
|
210 |
################################################### reload model ################################################## |
|
|
211 |
saver_best = tf.train.Saver() |
|
|
212 |
|
|
|
213 |
with tf.Session() as sess: |
|
|
214 |
|
|
|
215 |
# threading for parallel |
|
|
216 |
coord = tf.train.Coordinator() |
|
|
217 |
threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
|
|
218 |
|
|
|
219 |
sess.run(tf.global_variables_initializer()) |
|
|
220 |
sess.run(tf.local_variables_initializer()) |
|
|
221 |
|
|
|
222 |
model_path = tf.train.latest_checkpoint(results_dir) |
|
|
223 |
saver_best.restore(sess, model_path) |
|
|
224 |
|
|
|
225 |
for i in range(1): |
|
|
226 |
[loss_blur_test, loss_estimate_test, loss_rms_test, loss_ssim_test, GT_test] = sess.run( |
|
|
227 |
[blur_test, RGB_hat_test, cost_rms_test, cost_ssim_test, RGB_GT_test]) |
|
|
228 |
|
|
|
229 |
sharp_crop = GT_test[0, :, :, 0] |
|
|
230 |
gt_min = np.amin(np.ndarray.flatten(sharp_crop)) |
|
|
231 |
gt_max = np.amax(np.ndarray.flatten(sharp_crop)) |
|
|
232 |
scipy.misc.toimage(sharp_crop, cmin=gt_min, cmax=gt_max).save(image_save_dir + 'sharp_crop.png') |
|
|
233 |
|
|
|
234 |
np.savetxt(save_file_UNET, loss_rms_test, delimiter=',', newline='\n') |
|
|
235 |
np.savetxt(save_file_UNET_SSIM, loss_ssim_test, delimiter=',', newline='\n') |
|
|
236 |
|
|
|
237 |
np.save('blur.npy', loss_blur_test) |
|
|
238 |
np.save('estimate.npy', loss_estimate_test) |
|
|
239 |
|
|
|
240 |
|
|
|
241 |
coord.request_stop() |
|
|
242 |
coord.join(threads) |
|
|
243 |
|
|
|
244 |
print('Now saving the images') |
|
|
245 |
|
|
|
246 |
mask_blur = np.load('blur.npy') |
|
|
247 |
estimate = np.load('estimate.npy') |
|
|
248 |
|
|
|
249 |
|
|
|
250 |
def npy_to_images(npy_stack, save_name): |
|
|
251 |
for i in range(21): |
|
|
252 |
img_cur = npy_stack[i, :, :, 0] |
|
|
253 |
img_cur = 1 - img_cur |
|
|
254 |
img_min = np.amin(np.ndarray.flatten(img_cur)) |
|
|
255 |
img_max = np.amax(np.ndarray.flatten(img_cur)) |
|
|
256 |
save_name_cur = '00_'+ str(i) + '_' + save_name |
|
|
257 |
scipy.misc.toimage(img_cur, cmin=img_min, cmax=img_max).save(image_save_dir + save_name_cur) |
|
|
258 |
|
|
|
259 |
|
|
|
260 |
npy_to_images(mask_blur, 'deepDOF_blur.png') |
|
|
261 |
npy_to_images(estimate, 'deepDOF_hat.png') |