Diff of /DeepDOF_step2.py [000000] .. [0b950f]

# End-to-end optimization for EDOF
# Author: Yicheng Wu @ Rice University
# 03/29/2019
# 04/12/2019 parameter update
# 11/7/2019 update best model with valid_loss
# 12/3/2019 update best model with valid_rms instead of valid_loss
# 12/3/2019 update reblur cost = rms(blur, reblur)

import tensorflow as tf
import scipy.io as sio
import numpy as np
import os
import Network

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"            # only use GPU 0

results_dir = "./results/"
DATA_PATH = './DATA/'
TFRECORD_TRAIN_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']  # for testing purposes, train and valid point to the same tfrecords file
TFRECORD_VALID_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']

## optimizer learning rates
# use 0 in step 1:
# lr_optical = 0
# use 1e-9 in step 2:
lr_optical = 1e-9
lr_digital = 1e-4
print('lr_optical:' + str(lr_optical))
print('lr_digital:' + str(lr_digital))


##########################################   Functions  #############################################

# Peak SNR, could be used as a cost function
def tf_PSNR(a, b, max_val, name=None):
    with tf.name_scope(name, 'PSNR', [a, b]):
        # Need to convert the images to float32.  Scale max_val accordingly so that
        # PSNR is computed correctly.
        max_val = tf.cast(max_val, tf.float32)
        a = tf.cast(a, tf.float32)
        b = tf.cast(b, tf.float32)
        mse = tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
        psnr_val = tf.subtract(
            20 * tf.log(max_val) / tf.log(10.0),
            np.float32(10 / np.log(10)) * tf.log(mse),
            name='psnr')

        return psnr_val

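# Worked example (arithmetic check only, not executed): with max_val = 1.0 and
# mse = 1e-2, PSNR = 20*log10(1.0) - 10*log10(1e-2) = 0 + 20 = 20 dB.
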
####### read from the TFRECORD format #################
## for faster reading from the hard disk
def read_tfrecord(TFRECORD_PATH):
    # from tfrecord file to data
    N_w = 1000  # width of the images
    N_h = 1000  # height of the images
    queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True)
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(queue)

    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'sharp': tf.FixedLenFeature([], tf.string),
                                       })

    RGB_flat = tf.decode_raw(features['sharp'], tf.uint8)
    RGB = tf.reshape(RGB_flat, [N_h, N_w, 1])

    return RGB

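# Note: string_input_producer / TFRecordReader are TF1 queue-based APIs. A rough
# tf.data equivalent (a sketch only, not used by this script) would be:
#   def _parse(serialized):
#       feats = tf.parse_single_example(serialized, {'sharp': tf.FixedLenFeature([], tf.string)})
#       return tf.reshape(tf.decode_raw(feats['sharp'], tf.uint8), [1000, 1000, 1])
#   dataset = tf.data.TFRecordDataset(TFRECORD_PATH).map(_parse).shuffle(200).batch(21)
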
########## Preprocess the images #############
##  crop to patches
##  random flip
##  random brightness scaling
############################################
def data_augment(RGB_batch_float):
    # crop to N_raw x N_raw
    N_raw = 326  # 256 + 70: the extra N_B - 1 = 70 pixels are cropped away after the convolution to avoid boundary effects
    data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float)

    # flip both images and labels
    data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1)

    # scale the image intensity by a random factor in [0.8, 1.1)
    r1 = tf.random_uniform([]) * 0.3 + 0.8
    RGB_out = data2 * r1

    return RGB_out

############ Put data in batches #############
##  put in batch and shuffle
##  cast to float32
##  call data_augment for image preprocessing
## @param{TFRECORD_PATH}: path to the data
## @param{batchsize}: currently 21, one patch per each of the 21 PSFs
##############################################
def read2batch(TFRECORD_PATH, batchsize):
    # load the tfrecord and turn it into usable data
    RGB = read_tfrecord(TFRECORD_PATH)
    RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200,
                                       min_after_dequeue=50, num_threads=5)
    RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32)

    RGB_batch_float = data_augment(RGB_batch_float)

    return RGB_batch_float[:, :, :, 0:1]


def add_gaussian_noise(images, std):
    noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32)
    return tf.nn.relu(images + noise)

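# Since the images are normalized to [0, 1], std = 0.01 models roughly 1% sensor
# noise, and the relu clips any negative intensities produced by the noise.
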

########### fftshift2D ###################
## the same as fftshift in MATLAB
## works for complex numbers
def fft2dshift(input):
    dim = int(input.shape[1].value)  # dimension of the data
    channel1 = int(input.shape[0].value)  # channels of the first dimension
    if dim % 2 == 0:
        # even version
        # shift up and down
        u = tf.slice(input, [0, 0, 0], [channel1, int((dim) / 2), dim])
        d = tf.slice(input, [0, int((dim) / 2), 0], [channel1, int((dim) / 2), dim])
        du = tf.concat([d, u], axis=1)
        # shift left and right
        l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim) / 2)])
        r = tf.slice(du, [0, 0, int((dim) / 2)], [channel1, dim, int((dim) / 2)])
        output = tf.concat([r, l], axis=2)
    else:
        # odd version
        # shift up and down
        u = tf.slice(input, [0, 0, 0], [channel1, int((dim + 1) / 2), dim])
        d = tf.slice(input, [0, int((dim + 1) / 2), 0], [channel1, int((dim - 1) / 2), dim])
        du = tf.concat([d, u], axis=1)
        # shift left and right
        l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim + 1) / 2)])
        r = tf.slice(du, [0, 0, int((dim + 1) / 2)], [channel1, dim, int((dim - 1) / 2)])
        output = tf.concat([r, l], axis=2)
    return output

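# Sanity check: for dim = 5 the index order [0, 1, 2, 3, 4] along each axis
# becomes [3, 4, 0, 1, 2], and for dim = 4, [0, 1, 2, 3] becomes [2, 3, 0, 1],
# both matching MATLAB's fftshift.
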
#########  generate the out-of-focus phase  ###############
## @param{Phi_list}: a list of Phi values
## @param{N_B}: size of the blur kernel
## @return{OOFphase}
def gen_OOFphase(Phi_list, N_B):
    # returns (N_Phi, N_B, N_B, N_color)
    N = N_B
    x0 = np.linspace(-2.84, 2.84, N)  # 71/25 = 2.84
    xx, yy = np.meshgrid(x0, x0)
    OOFphase = np.empty([len(Phi_list), N, N, 1], dtype=np.float32)
    for j in range(len(Phi_list)):
        Phi = Phi_list[j]
        OOFphase[j, :, :, 0] = Phi * (xx ** 2 + yy ** 2)
    return OOFphase

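# The defocus phase is the standard quadratic pupil term
#   OOFphase(x, y) = Phi * (x^2 + y^2),
# with (x, y) normalized pupil coordinates and Phi the defocus coefficient in
# radians; below, Phi is swept over [-10, 10] to cover the targeted depth range.
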
##################  Generate the PSFs  ########################
## @param{h}: height map of the mask
## @param{OOFphase}: out-of-focus phase
## @param{wvls}: wavelength \lambda
## @param{idx}: aperture amplitude mask (pupil support)
## @param{N_B}: size of the blur kernel
#################################################################
def gen_PSFs(h, OOFphase, wvls, idx, N_B):
    n = 1.5  # refractive index of the mask material

    with tf.variable_scope("PSFs"):
        OOFphase_B = OOFphase[:, :, :, 0]
        phase_B = tf.add(2 * np.pi / wvls[0] * (n - 1) * h, OOFphase_B)  # phase modulation of the mask (phi_M)
        Pupil_B = tf.multiply(tf.complex(idx, 0.0), tf.exp(tf.complex(0.0, phase_B)), name='Pupil_B')  # pupil P
        Norm_B = tf.cast(N_B * N_B * np.sum(idx ** 2), tf.float32)  # by Parseval's theorem, sum(|fft2d(Pupil_B)|^2) = N_B^2 * sum(idx^2), so this normalizes each PSF to sum to 1
        PSF_B = tf.divide(tf.square(tf.abs(fft2dshift(tf.fft2d(Pupil_B)))), Norm_B, name='PSF_B')

    return tf.expand_dims(PSF_B, -1)

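# Sanity check (a sketch, assuming a session `sess` is available): with the
# Parseval normalization above, each depth's PSF should integrate to ~1:
#   PSFst = sess.run(PSFs)
#   assert np.allclose(PSFst.sum(axis=(1, 2, 3)), 1.0, atol=1e-3)
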
################  blur the images using the PSFs  ##################
## the same patch at different depths is put in one stack
################################################################
def one_wvl_blur(im, PSFs0):
    N_B = PSFs0.shape[1].value
    N_Phi = PSFs0.shape[0].value
    N_im = im.shape[1].value
    N_im_out = N_im - N_B + 1  # the final image size after blurring

    sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]),
                         [0, 2, 3, 1])  # reshape to make N_Phi the last channel
    PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1)
    blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID')
    blurStack = tf.transpose(
        tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]),
        perm=[0, 2, 3, 1])  # stack all N_Phi images along the first dimension

    return blurStack

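# How the depthwise trick above works: the N_Phi depth planes are moved into the
# channel axis, so `sharp` is [1, N_im, N_im, N_Phi] and `PSFs` is
# [N_B, N_B, N_Phi, 1]; tf.nn.depthwise_conv2d then convolves each depth plane
# with its own PSF in a single call, and blurStack comes back as
# [N_Phi, N_im_out, N_im_out, 1].
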
def blurImage_diffPatch_diffDepth(RGB, PSFs):
    blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0])

    return blur


####################### system ##########################
## @param{PSFs}: the PSFs
## @param{RGB_batch_float}: patches
## @param{phase_BN}: batch normalization flag, True only during training
########################################################
def system(PSFs, RGB_batch_float, phase_BN=True):
    with tf.variable_scope("system", reuse=tf.AUTO_REUSE):
        blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs)  # size [N_Phi, Nx, Ny, 1]

        # noise
        sigma = 0.01
        blur_noisy = add_gaussian_noise(blur, sigma)

        RGB_hat = Network.UNet(blur_noisy, phase_BN)

        return blur, RGB_hat

######################  RMS cost #############################
## @param{GT}: ground truth
## @param{hat}: reconstruction
##############################################################
def cost_rms(GT, hat):
    cost = tf.sqrt(tf.reduce_mean(tf.square(GT - hat)))
    return cost


##########  compare the reconstruction, re-blurred by the PSFs, with the U-Net input  ############
## important for EDOF: encourages the reconstruction to be consistent with the PSF model
## @param{RGB_hat}: U-Net reconstructed image
## @param{PSFs}: PSFs used
## @param{blur}: all-in-focus image convolved with the PSFs
## @param{N_B}: size of the blur kernel
## @return{reblur}: reconstruction blurred again
## @return{cost}: RMS error between blur_GT and reblur
##############################################################################
def cost_reblur(RGB_hat, PSFs, blur, N_B):
    reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs)
    blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2), :]  # crop the patch to 256x256

    cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur)))

    return reblur, cost

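# Note: cost_reblur is defined as a PSF-consistency term (see the 12/3/2019
# header note), but the training loss below uses cost_rms only.
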
######################################### Set parameters   ###############################################

# def main():

zernike = sio.loadmat('zernike_basis_150mm.mat')
u2 = zernike['u2']  # basis of Zernike polynomials
idx = zernike['idx']
idx = idx.astype(np.float32)

a_zernike_mat = sio.loadmat('a_zernike_cubic_150mm.mat')
a_zernike_fix = a_zernike_mat['a']
a_zernike_fix = a_zernike_fix * 4
a_zernike_fix = tf.convert_to_tensor(a_zernike_fix)

N_B = 71  # size of the blur kernel
wvls = np.array([550]) * 1e-9  # wavelength 550 nm
N_color = len(wvls)

N_modes = u2.shape[1]  # number of Zernike modes

# generate the defocus phase
N_Phi = 21
Phi_list = np.linspace(-10, 10, N_Phi, dtype=np.float32)  # defocus
OOFphase = gen_OOFphase(Phi_list, N_B)  # returns (N_Phi, N_B, N_B, N_color)

# baseline offset for the height map
c = 0

####################################   Build the architecture  #####################################################


with tf.variable_scope("PSFs"):
    a_zernike_learn = tf.get_variable("a_zernike_learn", [N_modes, 1], initializer=tf.zeros_initializer(),
                                      constraint=lambda x: tf.clip_by_value(x, -wvls[0] / 2, wvls[0] / 2))
    a_zernike = a_zernike_learn + a_zernike_fix  # fixed cubic part plus the learned part
    g = tf.matmul(u2, a_zernike)
    h = tf.nn.relu(tf.reshape(g, [N_B, N_B]) + c,  # c: baseline
                   name='heightMap')  # height map of the phase mask, must be non-negative
    PSFs = gen_PSFs(h, OOFphase, wvls, idx, N_B)  # returns (N_Phi, N_B, N_B, N_color)

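# The mask is parameterized as h = relu(u2 @ (a_zernike_fix + a_zernike_learn) + c):
# a fixed cubic phase plus learned Zernike coefficients, with each learned
# coefficient clipped to [-wvls[0]/2, +wvls[0]/2] so it stays within half a wavelength.
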
batch_size = N_Phi  # each patch in the batch is blurred at a different depth; the reshape in one_wvl_blur breaks if this is not N_Phi


RGB_batch_float = read2batch(TFRECORD_TRAIN_PATH, batch_size)
RGB_batch_float_valid = read2batch(TFRECORD_VALID_PATH, batch_size)

[blur_train, RGB_hat_train] = system(PSFs, RGB_batch_float)
[blur_valid, RGB_hat_valid] = system(PSFs, RGB_batch_float_valid, phase_BN=False)

# cost function
with tf.name_scope("cost"):
    RGB_GT_train = RGB_batch_float[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
                   int((N_B - 1) / 2):-int((N_B - 1) / 2), :]                    # crop the all-in-focus image to 256x256 to match the 'VALID' convolution output
    RGB_GT_valid = RGB_batch_float_valid[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
                   int((N_B - 1) / 2):-int((N_B - 1) / 2), :]

    cost_rms_train = cost_rms(RGB_GT_train, RGB_hat_train)
    cost_rms_valid = cost_rms(RGB_GT_valid, RGB_hat_valid)
    cost_train = cost_rms_train
    cost_valid = cost_rms_valid

# train the digital and optical parts separately
vars_optical = tf.trainable_variables("PSFs")
vars_digital = tf.trainable_variables("system")

opt_optical = tf.train.AdamOptimizer(lr_optical)
opt_digital = tf.train.AdamOptimizer(lr_digital)

global_step = tf.Variable(0, name='global_step', trainable=False)  # initialize the step counter

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # ensure the batch-norm statistics are updated before each train step
with tf.control_dependencies(update_ops):
    grads = tf.gradients(cost_train, vars_optical + vars_digital)
    grads_optical = grads[:len(vars_optical)]
    grads_digital = grads[len(vars_optical):]
    train_op_optical = opt_optical.apply_gradients(zip(grads_optical, vars_optical))
    train_op_digital = opt_digital.apply_gradients(zip(grads_digital, vars_digital))
    train_op = tf.group(train_op_optical, train_op_digital)

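# The gradients are computed once over the joint loss and then split, so the
# optical variables (scope "PSFs") and digital variables (scope "system") are
# updated by two Adam optimizers with different learning rates, grouped into a
# single train_op.
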
# tensorboard
tf.summary.scalar('cost_train', cost_train)
tf.summary.scalar('cost_valid', cost_valid)
tf.summary.scalar('cost_rms_train', cost_rms_train)
tf.summary.scalar('cost_rms_valid', cost_rms_valid)

tf.summary.histogram('a_zernike', a_zernike)
tf.summary.histogram('a_zernike_learn', a_zernike_learn)
tf.summary.histogram('a_zernike_fix', a_zernike_fix)
tf.summary.image('Height', tf.expand_dims(tf.expand_dims(h, 0), -1))
tf.summary.image('sharp_valid', tf.image.convert_image_dtype(RGB_GT_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('blur_valid', tf.image.convert_image_dtype(blur_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('RGB_hat_valid', tf.image.convert_image_dtype(RGB_hat_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('PSF_n100', PSFs[0:1, :, :])    # Phi = -10
tf.summary.image('PSF_p90', PSFs[19:20, :, :])   # Phi = +9
tf.summary.image('PSF_p100', PSFs[20:21, :, :])  # Phi = +10
merged = tf.summary.merge_all()

##########################################   Train  #############################################

# variables_to_restore = [v for v in tf.global_variables() if v.name.startswith('system')]
# saver = tf.train.Saver(variables_to_restore)
saver_all = tf.train.Saver(max_to_keep=1)
saver_best = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    best_dir = 'best_model/'
    if not os.path.exists(results_dir + best_dir):
        os.makedirs(results_dir + best_dir)
        best_valid_loss = 100  # large initial value, overwritten by the first logged validation loss
    else:
        best_valid_loss = np.loadtxt(results_dir + 'best_valid_loss.txt')
        print('Current best valid loss = ' + str(best_valid_loss))

    if not tf.train.checkpoint_exists(results_dir + 'checkpoint'):
        # option 1: start a new run
        out_all = np.empty((0, 2))  # each row stores [train_loss, valid_loss]
        print('Start to save at: ', results_dir)
    else:
        # option 2: continue from the latest checkpoint
        print(results_dir)
        model_path = tf.train.latest_checkpoint(results_dir)
        load_path = saver_all.restore(sess, model_path)
        out_all = np.load(results_dir + 'out_all.npy')
        print('Continue to save at: ', results_dir)

    train_writer = tf.summary.FileWriter(results_dir + '/summary/', sess.graph)

    # threads for the parallel input queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(100000):
        ## run one training step
        train_op.run()  # step 2: trains both the optical and the digital parts

        if i % 10 == 0:
            [train_summary, loss_train, loss_valid, loss_rms_valid] = sess.run(
                [merged, cost_train, cost_valid, cost_rms_valid])
            train_writer.add_summary(train_summary, i)

            print("Iter " + str(i) + ", Train Loss = " + \
                  "{:.6f}".format(loss_train) + ", Valid Loss = " + \
                  "{:.6f}".format(loss_valid))

            # save the losses
            out = np.array([[loss_train, loss_valid]])
            out_all = np.vstack((out_all, out))
            np.save(results_dir + 'out_all.npy', out_all)

            saver_all.save(sess, results_dir + "model.ckpt", global_step=i)

            [ht, at, PSFst] = sess.run([h, a_zernike, PSFs])
            np.savetxt(results_dir + 'HeightMap.txt', ht)
            np.savetxt(results_dir + 'a_zernike.txt', at)
            np.save(results_dir + 'PSFs.npy', PSFst)

            if (loss_rms_valid < best_valid_loss) and (i > 1):
                best_valid_loss = loss_rms_valid
                np.savetxt(results_dir + 'best_valid_loss.txt', [best_valid_loss])
                saver_best.save(sess, results_dir + best_dir + "model.ckpt")
                np.save(results_dir + best_dir + 'out_all.npy', out_all)
                np.savetxt(results_dir + best_dir + 'HeightMap.txt', ht)
                print('best at iter ' + str(i) + ' with loss = ' + str(best_valid_loss))

    train_writer.close()
    coord.request_stop()
    coord.join(threads)