test_image_all_720um.py
import tensorflow as tf
import numpy as np
import cv2
from pathlib import Path
import scipy.misc
import Network
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # only use GPU 0

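# NOTE (environment assumption): scipy.misc.toimage, used below to save images, was
# deprecated in SciPy 1.0 and removed in SciPy 1.2, so this script assumes SciPy < 1.2.
# A hedged alternative for newer SciPy would rescale to [0, 255] and save with PIL, e.g.
#   Image.fromarray(np.uint8(255 * (img - img.min()) / (img.max() - img.min()))).save(path)
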
results_dir = './results/'
results_best_dir = './results/best_model/'

save_file_UNET = './results/npo_cubic_e2e_MSE.csv'
save_file_UNET_SSIM = './results/npo_cubic_e2e_SSIM.csv'

# read the test image
image_read_dir = './image/'
GT_path = str(Path(image_read_dir + 'test_img.png'))
gt_img = cv2.imread(GT_path, 0) / 255  # grayscale, scaled to [0, 1]
GT = np.tile(gt_img, [21, 1, 1])  # replicate across the 21 defocus planes
GT = np.expand_dims(GT, -1)
GT = tf.convert_to_tensor(GT, dtype=tf.float32)

# directory for saving output images
image_save_dir = './image/results/'

####### read from the TFRECORD format #################
## for faster reading from the hard disk
def read_tfrecord(TFRECORD_PATH):
    # from tfrecord file to data
    N_w = 326  # size of the images
    N_h = 326
    queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True)
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(queue)

    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'sharp': tf.FixedLenFeature([], tf.string),
                                       })

    RGB_flat = tf.decode_raw(features['sharp'], tf.uint8)
    RGB = tf.reshape(RGB_flat, [N_h, N_w, 1])

    return RGB

########## Preprocess the images #############
##  crop to patches
##  random flip
##  random brightness scaling
##############################################
def data_augment(RGB_batch_float):
    # crop to N_raw x N_raw
    N_raw = 326  # for the boundary effect, 256 + 70; will need cropping after convolution

    data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float)

    # flip both images and labels
    data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1)

    # randomly scale the image brightness by a factor in [0.8, 1.1)
    r1 = tf.random_uniform([]) * 0.3 + 0.8
    RGB_out = data2 * r1

    return RGB_out

############ Put data in batches #############
##  tile the decoded frame across the N_Phi defocus planes
##  cast to float32
##  (the shuffle_batch / data_augment path used in training is kept commented out)
## @param{TFRECORD_PATH}: path to the data
## @param{batchsize}: currently 21 for the 21 PSFs
##############################################
def read2batch(TFRECORD_PATH, batchsize):
    # load tfrecord and make it usable data
    RGB = read_tfrecord(TFRECORD_PATH)
    # RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200, num_threads=5)
    RGB = tf.expand_dims(RGB, axis=0)
    RGB_batch = tf.tile(RGB, [21, 1, 1, 1])
    RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32)

    # crop the target to 298 x 298; the 70-pixel margin is consumed by the blur convolution
    RGB_batch_float = tf.image.resize_image_with_crop_or_pad(RGB_batch_float, 298, 298)

    return RGB_batch_float[:, :, :, 0:1]

def add_gaussian_noise(images, std):
    # additive Gaussian read noise; ReLU clips negative intensities to zero
    noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32)
    return tf.nn.relu(images + noise)

################  blur the images using PSFs  ##################
## the same patch at different depths is put in a stack
################################################################
def one_wvl_blur(im, PSFs0):
    N_B = PSFs0.shape[1].value    # PSF kernel size
    N_Phi = PSFs0.shape[0].value  # number of defocus planes
    N_im = im.shape[1].value
    N_im_out = N_im - N_B + 1  # the final image size after blurring (VALID convolution)

    sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]),
                         [0, 2, 3, 1])  # reshape to put N_Phi in the last (channel) dimension
    PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1)
    blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID')
    blurStack = tf.transpose(
        tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]),
        perm=[0, 2, 3, 1])  # stack all N_Phi images back into the first (batch) dimension

    return blurStack


def blurImage_diffPatch_diffDepth(RGB, PSFs):
    blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0])

    return blur

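# Shape walk-through (descriptive note, assuming the parameter values set below,
# N_Phi = 21 and N_B = 71, and 298 x 298 inputs as produced by read2batch; in this
# test script the actual spatial size is that of test_img.png):
#   im    [batch * N_Phi, 298, 298]  ->  sharp  [batch, 298, 298, N_Phi]
#   PSFs0 [N_Phi, 71, 71]            ->  PSFs   [71, 71, N_Phi, 1]
#   depthwise_conv2d applies each PSF to its own defocus channel (VALID padding),
#   giving blurAll [batch, 228, 228, N_Phi] since 298 - 71 + 1 = 228, which is then
#   unstacked back into blurStack [batch * N_Phi, 228, 228, 1].
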
####################### system ##########################
## @param{PSFs}: the PSFs
## @param{RGB_batch_float}: patches
## @param{phase_BN}: batch normalization flag, True only during training
##########################################################
def system(PSFs, RGB_batch_float, phase_BN=False):
    with tf.variable_scope("system", reuse=tf.AUTO_REUSE):
        blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs)  # size [batch_size * N_Phi, Nx, Ny, 1]

        # additive Gaussian noise
        sigma = 0.01
        blur_noisy = add_gaussian_noise(blur, sigma)

        RGB_hat = Network.UNet(blur_noisy, phase_BN)

        return blur_noisy, RGB_hat

######################  RMS cost #############################
## @param{GT}: ground truth
## @param{hat}: reconstruction
##############################################################
def cost_rms(GT, hat):
    # mean over the two spatial axes, then sqrt -> per-image, per-channel RMS error
    cost = tf.sqrt(tf.reduce_mean(tf.reduce_mean(tf.square(GT - hat), 1), 1))
    return cost


######################  SSIM cost #############################
## @param{GT}: ground truth
## @param{hat}: reconstruction
##############################################################
def cost_ssim(GT, hat):
    cost = tf.image.ssim(GT, hat, 1.0)  # assumes image intensities range from 0 to 1
    cost = tf.expand_dims(cost, axis=1)
    return cost

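# Note on shapes (assuming the [21, H, W, 1] evaluation batch constructed above):
# cost_rms returns a [21, 1] per-image RMS vector and cost_ssim returns a [21, 1]
# per-image SSIM vector, so both can be written directly to CSV with np.savetxt.
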
##########  compare the reconstruction re-blurred with the U-Net input  ############
## important for EDOF to utilize the PSF information
## @param{RGB_hat}: U-Net reconstructed image
## @param{PSFs}: PSF used
## @param{blur}: all-in-focus image convolved with the PSF
## @param{N_B}: size of the blur kernel
## @return{reblur}: reconstruction blurred
## @return{cost}: l2 norm between blur_GT and reblur
#####################################################################################
def cost_reblur(RGB_hat, PSFs, blur, N_B):
    reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs)
    blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2),
              :]  # crop blur so it matches the spatial size of reblur

    cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur)))

    return reblur, cost

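# (Note: cost_reblur is defined for completeness but is not called in this test script.)
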
########################################################  PARAMETER ####################################################
N_B = 71

N_Phi = 21
batch_size = N_Phi
Phi_list = np.linspace(-10, 10, N_Phi, dtype=np.float32)  # defocus values

PSFs = np.load(results_dir + 'PSFs.npy')
PSFs = tf.convert_to_tensor(PSFs, dtype=tf.float32)

####################################################### architecture ###################################################
RGB_batch_float_test = GT

[blur_test, RGB_hat_test] = system(PSFs, RGB_batch_float_test)

# cost function
with tf.name_scope("cost"):
    RGB_GT_test = RGB_batch_float_test[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
                  int((N_B - 1) / 2):-int((N_B - 1) / 2), :]  # crop the all-in-focus image to match the network output

    cost_rms_test = cost_rms(RGB_GT_test, RGB_hat_test)
    cost_ssim_test = cost_ssim(RGB_GT_test, RGB_hat_test)

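# Size check (descriptive note): with N_B = 71 the crop above removes (N_B - 1) / 2 = 35
# pixels on each side, and the VALID blur likewise shrinks the image by 70 pixels
# (e.g. 298 -> 298 - 71 + 1 = 228), so RGB_GT_test matches RGB_hat_test spatially,
# assuming Network.UNet preserves the spatial size of its input.
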
###################################################  reload model ##################################################
saver_best = tf.train.Saver()

with tf.Session() as sess:

    # threading for the parallel input queue runners
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    model_path = tf.train.latest_checkpoint(results_dir)
    saver_best.restore(sess, model_path)

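    # (Descriptive note: tf.train.latest_checkpoint returns None if no checkpoint is
    # found in results_dir, in which case the restore above will fail; the script
    # assumes a trained checkpoint already exists there.)
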
    for i in range(1):  # single evaluation pass
        [loss_blur_test, loss_estimate_test, loss_rms_test, loss_ssim_test, GT_test] = sess.run(
            [blur_test, RGB_hat_test, cost_rms_test, cost_ssim_test, RGB_GT_test])

        # save the cropped ground-truth patch, rescaled to its own intensity range
        sharp_crop = GT_test[0, :, :, 0]
        gt_min = np.amin(sharp_crop)
        gt_max = np.amax(sharp_crop)
        scipy.misc.toimage(sharp_crop, cmin=gt_min, cmax=gt_max).save(image_save_dir + 'sharp_crop.png')

        np.savetxt(save_file_UNET, loss_rms_test, delimiter=',', newline='\n')
        np.savetxt(save_file_UNET_SSIM, loss_ssim_test, delimiter=',', newline='\n')

    np.save('blur.npy', loss_blur_test)
    np.save('estimate.npy', loss_estimate_test)

    coord.request_stop()
    coord.join(threads)

print('Now saving the images')

mask_blur = np.load('blur.npy')
estimate = np.load('estimate.npy')


def npy_to_images(npy_stack, save_name):
    # save each of the 21 defocus planes as an inverted, intensity-rescaled PNG
    for i in range(21):
        img_cur = npy_stack[i, :, :, 0]
        img_cur = 1 - img_cur  # invert intensities
        img_min = np.amin(img_cur)
        img_max = np.amax(img_cur)
        save_name_cur = '00_' + str(i) + '_' + save_name
        scipy.misc.toimage(img_cur, cmin=img_min, cmax=img_max).save(image_save_dir + save_name_cur)


npy_to_images(mask_blur, 'deepDOF_blur.png')
npy_to_images(estimate, 'deepDOF_hat.png')