Diff of /DeepDOF_step1.py [000000] .. [0b950f]

--- a
+++ b/DeepDOF_step1.py
@@ -0,0 +1,437 @@
+# End-to-end optimization for EDOF
+# Author: Yicheng Wu @ Rice University
+# 03/29/2019
+# 04/12/2019 parameter update
+# 11/7/2019 update best model with valid_loss
+# 12/3/2019 update best model with valid_rms instead of valid_loss
+# 12/3/2019 update reblur cost = rms(blur, reblur)
+
+import tensorflow as tf
+import scipy.io as sio
+import numpy as np
+import os
+import Network
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"            # only uses GPU 1
+
+results_dir = "./results/"
+DATA_PATH = "./DATA/"
+TFRECORD_TRAIN_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']  # for testing purposes, train and
+TFRECORD_VALID_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']  # validation read the same file
+
+## optimizer learning rates
+# use 0 in step 1:
+lr_optical = 0
+# use 1e-9 in step 2:
+# lr_optical = 1e-9
+lr_digital = 1e-4
+print('lr_optical:' + str(lr_optical))
+print('lr_digital:' + str(lr_digital))
+
+
+##########################################   Functions  #############################################
+
+# Peak SNR; could be used as a cost function
+def tf_PSNR(a, b, max_val, name=None):
+    with tf.name_scope(name, 'PSNR', [a, b]):
+        # Need to convert the images to float32.  Scale max_val accordingly so that
+        # PSNR is computed correctly.
+        max_val = tf.cast(max_val, tf.float32)
+        a = tf.cast(a, tf.float32)
+        b = tf.cast(b, tf.float32)
+        mse = tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
+        psnr_val = tf.subtract(
+            20 * tf.log(max_val) / tf.log(10.0),
+            np.float32(10 / np.log(10)) * tf.log(mse),
+            name='psnr')
+
+        return psnr_val
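+
+# Usage sketch (an assumption, not wired into this script): for images in [0, 1],
+#   psnr = tf_PSNR(RGB_GT_train, RGB_hat_train, max_val=1.0)
+# and minimizing -psnr would be an alternative to cost_rms below.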
+
+
+####### read from the TFRECORD format #################
+## for faster reading from the hard disk
+def read_tfrecord(TFRECORD_PATH):
+    # from tfrecord file to data
+    N_w = 1000 # size of the images
+    N_h = 1000
+    queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True)
+    reader = tf.TFRecordReader()
+
+    _, serialized_example = reader.read(queue)  
+
+    features = tf.parse_single_example(serialized_example,
+                                       features={
+                                           'sharp': tf.FixedLenFeature([], tf.string),
+                                       })
+
+    RGB_flat = tf.decode_raw(features['sharp'], tf.uint8)
+    RGB = tf.reshape(RGB_flat, [N_h, N_w, 1]) 
+
+    return RGB
+
+
+
+########## Preprocess the images #############
+##  crop to patches
+##  random flip
+##  random brightness scaling
+############################################  
+def data_augment(RGB_batch_float):
+    # crop to N_raw x N_raw
+    N_raw = 326 # 256 + 70 margin for the boundary effect; cropped back after the PSF convolution
+    data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float)
+
+    # random horizontal and vertical flips
+    data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1)
+
+    # random brightness scaling, factor in [0.8, 1.1)
+    r1 = tf.random_uniform([]) * 0.3 + 0.8
+    RGB_out = data2 * r1
+
+    return RGB_out
+
+
+
+############ Put data in batches #############
+##  put in batch and shuffle
+##  cast to float32
+##  call data_augment for image preprocess
+## @param{TFRECORD_PATH}: path to the data
+## @param{batchsize}: currently 21 for the 21 PSFs
+##############################################
+def read2batch(TFRECORD_PATH, batchsize):
+    # load tfrecord and make them to be usable data
+    RGB = read_tfrecord(TFRECORD_PATH)
+    RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200,
+                                       min_after_dequeue=50, num_threads=5)
+    RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32)
+
+    RGB_batch_float = data_augment(RGB_batch_float)
+
+    return RGB_batch_float[:,:,:,0:1]
+
+
+def add_gaussian_noise(images, std):
+    noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32)
+    return tf.nn.relu(images + noise)  # clip negative values introduced by the noise
+
+
+
+
+########### fftshift2D ###################
+## the same as fftshift in MATLAB
+## works for complex numbers
+def fft2dshift(input):
+    dim = int(input.shape[1].value)  # dimension of the data
+    channel1 = int(input.shape[0].value)  # channels for the first dimension
+    if dim % 2 == 0:
+        # even version
+        # shift up and down
+        u = tf.slice(input, [0, 0, 0], [channel1, int((dim) / 2), dim])
+        d = tf.slice(input, [0, int((dim) / 2), 0], [channel1, int((dim) / 2), dim])
+        du = tf.concat([d, u], axis=1)
+        # shift left and right
+        l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim) / 2)])
+        r = tf.slice(du, [0, 0, int((dim) / 2)], [channel1, dim, int((dim) / 2)])
+        output = tf.concat([r, l], axis=2)
+    else:
+        # odd version
+        # shift up and down
+        u = tf.slice(input, [0, 0, 0], [channel1, int((dim + 1) / 2), dim])
+        d = tf.slice(input, [0, int((dim + 1) / 2), 0], [channel1, int((dim - 1) / 2), dim])
+        du = tf.concat([d, u], axis=1)
+        # shift left and right
+        l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim + 1) / 2)])
+        r = tf.slice(du, [0, 0, int((dim + 1) / 2)], [channel1, dim, int((dim - 1) / 2)])
+        output = tf.concat([r, l], axis=2)
+    return output
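+
+# Cross-check sketch (an assumption, not run during training): fft2dshift should
+# agree with np.fft.fftshift applied per channel of a [C, N, N] array, e.g.
+#   x = np.random.rand(3, 5, 5).astype(np.complex64)
+#   ref = np.fft.fftshift(x, axes=(1, 2))
+#   out = tf.Session().run(fft2dshift(tf.constant(x)))
+#   assert np.allclose(out, ref)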
+
+
+
+#########  generate out-of-focus phase  ###############
+## @param{Phi_list}: a list of Phi values
+## @param{N_B}: size of the blur kernel
+## @return{OOFphase} 
+def gen_OOFphase(Phi_list, N_B):
+    # return (Phi_list,pixel,pixel,color)
+    N = N_B
+    x0 = np.linspace(-2.84, 2.84, N) # 71/25 =2.84
+    xx, yy = np.meshgrid(x0, x0)
+    OOFphase = np.empty([len(Phi_list), N, N, 1], dtype=np.float32)
+    for j in range(len(Phi_list)):
+        Phi = Phi_list[j]
+        OOFphase[j, :, :, 0] = Phi * (xx ** 2 + yy ** 2)
+    return OOFphase
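+
+# Note on scaling (assumed from 71/25 = 2.84 above): the unit-radius aperture
+# spans 25 of the N_B = 71 samples, so at the aperture edge (xx**2 + yy**2 = 1)
+# the defocus phase equals Phi radians.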
+
+
+
+##################  Generates the PSFs  ########################
+## @param{h}: height map of the mask
+## @param{OOFphase}: out-of-focus phase
+## @param{wvls}: wavelength \lambda
+## @param{idx}: pupil aperture amplitude (loaded with the Zernike basis)
+## @param{N_B}: size of the blur kernel
+#################################################################
+def gen_PSFs(h, OOFphase, wvls, idx, N_B):
+    n = 1.5  # refractive index of the mask material
+
+    with tf.variable_scope("PSFs"):
+        OOFphase_B = OOFphase[:, :, :, 0]
+        phase_B = tf.add(2 * np.pi / wvls[0] * (n - 1) * h, OOFphase_B) # phase modulation of mask (phi_M)
+        Pupil_B = tf.multiply(tf.complex(idx, 0.0), tf.exp(tf.complex(0.0, phase_B)), name='Pupil_B') # pupil P
+        Norm_B = tf.cast(N_B * N_B * np.sum(idx ** 2), tf.float32)  # Parseval normalization so each PSF sums to 1
+        PSF_B = tf.divide(tf.square(tf.abs(fft2dshift(tf.fft2d(Pupil_B)))), Norm_B, name='PSF_B')
+
+    return tf.expand_dims(PSF_B, -1)
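+
+# Sanity-check sketch (an assumption): with the Parseval normalization above,
+# each PSF should sum to ~1, e.g.
+#   psfs = sess.run(PSFs)  # shape (N_Phi, N_B, N_B, 1)
+#   assert np.allclose(psfs.sum(axis=(1, 2, 3)), 1.0, atol=1e-4)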
+
+
+
+################  blur the images using PSFs  ##################
+## the same patch at different depths is stacked along the batch dimension
+################################################################
+def one_wvl_blur(im, PSFs0):
+    N_B = PSFs0.shape[1].value
+    N_Phi = PSFs0.shape[0].value
+    N_im = im.shape[1].value
+    N_im_out = N_im - N_B + 1  # the final image size after blurring
+
+    sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]),
+                         [0, 2, 3, 1])  # reshape to make N_Phi in the last channel
+    PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1)
+    blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID')
+    blurStack = tf.transpose(
+        tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]),
+        perm=[0, 2, 3, 1])  # stack all N_Phi images to the first dimension
+
+    return blurStack
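+
+# Shape bookkeeping (with batch_size = N_Phi = 21, N_im = 326, N_B = 71):
+# sharp is [1, 326, 326, 21], PSFs is [71, 71, 21, 1], depthwise_conv2d
+# yields [1, 256, 256, 21], and blurStack is [21, 256, 256, 1]:
+# one blurred 256x256 patch per depth.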
+
+
+def blurImage_diffPatch_diffDepth(RGB, PSFs):
+    blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0])
+
+    return blur
+
+
+####################### system ##########################
+## @param{PSFs}: the PSFs
+## @param{RGB_batch_float}: patches
+## @param{phase_BN}: batch normalization, True only during training
+########################################################
+def system(PSFs, RGB_batch_float, phase_BN=True): 
+    with tf.variable_scope("system", reuse=tf.AUTO_REUSE):
+        blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs)  # shape [batch_size, Nx, Ny, 1], one depth per slice
+
+        # noise
+        sigma = 0.01
+        blur_noisy = add_gaussian_noise(blur, sigma)
+
+        RGB_hat = Network.UNet(blur_noisy, phase_BN)
+
+        return blur, RGB_hat
+
+
+######################  RMS cost #############################
+## @param{GT}: ground truth
+## @param{hat}: reconstruction
+##############################################################
+def cost_rms(GT, hat):
+    cost = tf.sqrt(tf.reduce_mean(tf.square(GT - hat)))
+    return cost
+
+
+##########  compare the reconstruction, re-blurred by the PSFs, with the U-Net input  ############
+## important for EDOF to utilize the PSF information
+## @param{RGB_hat}: Unet reconstructed image
+## @param{PSFs}: PSF used
+## @param{blur}: all-in-focus image conv PSF
+## @param{N_B}: size of blur kernel
+## @return{reblur}: reconstruction blurred
+## @return{cost}: l2 norm between blur_GT and reblur
+##############################################################################
+def cost_reblur(RGB_hat, PSFs, blur, N_B):
+    reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs)
+    blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2), :]  # crop blur to 186x186 to match the re-blurred size
+
+    cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur)))
+
+    return reblur, cost
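+
+# Note: cost_reblur implements the re-blur consistency loss mentioned in the
+# header, but it is not added to cost_train in this step-1 script
+# (cost_train = cost_rms_train only).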
+
+
+######################################### Set parameters   ###############################################
+
+# def main():
+
+zernike = sio.loadmat('zernike_basis_150mm.mat')
+u2 = zernike['u2']  # basis of zernike poly
+idx = zernike['idx']
+idx = idx.astype(np.float32)
+
+a_zernike_mat = sio.loadmat('a_zernike_cubic_150mm.mat')
+a_zernike_fix = a_zernike_mat['a']
+a_zernike_fix = a_zernike_fix * 4
+a_zernike_fix = tf.convert_to_tensor(a_zernike_fix)
+
+N_B = 71  # size of the blur kernel
+wvls = np.array([550]) * 1e-9 # wavelength 550 nm
+N_color = len(wvls)
+
+N_modes = u2.shape[1]  # number of Zernike modes
+
+# generate the defocus phase
+N_Phi = 21
+Phi_list = np.linspace(-10, 10, N_Phi, dtype=np.float32)  # defocus; dtype must be a keyword (the 4th positional argument is endpoint)
+OOFphase = gen_OOFphase(Phi_list, N_B)  # return (N_Phi,N_B,N_B,N_color)
+
+# baseline offset for the heightmap
+c = 0
+
+####################################   Build the architecture  #####################################################
+
+
+with tf.variable_scope("PSFs"):
+    a_zernike_learn = tf.get_variable("a_zernike_learn", [N_modes, 1], initializer=tf.zeros_initializer(),
+                                constraint=lambda x: tf.clip_by_value(x, -wvls[0] / 2, wvls[0] / 2))
+    a_zernike = a_zernike_learn + a_zernike_fix # fixed cubic and learning part
+    g = tf.matmul(u2, a_zernike)
+    h = tf.nn.relu(tf.reshape(g, [N_B, N_B])+c, # c: baseline
+                   name='heightMap')  # height map of the phase mask, should be all positive
+    PSFs = gen_PSFs(h, OOFphase, wvls, idx, N_B)  # return (N_Phi, N_B, N_B, N_color)
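+
+# Assumed shapes: u2 is [N_B * N_B, N_modes], so g = u2 @ a_zernike is
+# [N_B * N_B, 1] and reshapes to the 71x71 height map; the clip constraint
+# keeps each learned coefficient within half a wavelength of zero.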
+
+
+batch_size = N_Phi  # each patch in the batch is blurred at a different depth; the reshapes in one_wvl_blur require batch_size == N_Phi
+
+
+RGB_batch_float = read2batch(TFRECORD_TRAIN_PATH, batch_size)
+RGB_batch_float_valid = read2batch(TFRECORD_VALID_PATH, batch_size)
+
+[blur_train, RGB_hat_train] = system(PSFs, RGB_batch_float)
+[blur_valid, RGB_hat_valid] = system(PSFs, RGB_batch_float_valid, phase_BN=False)
+
+# cost function
+with tf.name_scope("cost"):
+    RGB_GT_train = RGB_batch_float[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
+                   int((N_B - 1) / 2):-int((N_B - 1) / 2), :]                    # crop the all-in-focus images to the 256x256 network output
+    RGB_GT_valid = RGB_batch_float_valid[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
+                   int((N_B - 1) / 2):-int((N_B - 1) / 2), :]
+
+    cost_rms_train = cost_rms(RGB_GT_train, RGB_hat_train)
+    cost_rms_valid = cost_rms(RGB_GT_valid, RGB_hat_valid)
+
+    cost_train = cost_rms_train
+    cost_valid = cost_rms_valid
+
+# train the digital and optical parts separately
+vars_optical = tf.trainable_variables("PSFs")
+vars_digital = tf.trainable_variables("system")
+
+lr_digital_ph = tf.placeholder(tf.float32, shape=[])  # fed each step so the decay below actually takes effect
+opt_optical = tf.train.AdamOptimizer(lr_optical)
+opt_digital = tf.train.AdamOptimizer(lr_digital_ph)
+
+global_step = tf.Variable(0, name='global_step', trainable=False)  # global step counter (not passed to apply_gradients here)
+
+update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batch-norm moving-average updates must run with the train step
+with tf.control_dependencies(update_ops):
+    grads = tf.gradients(cost_train, vars_optical + vars_digital)
+    grads_optical = grads[:len(vars_optical)]
+    grads_digital = grads[len(vars_optical):]
+    train_op_optical = opt_optical.apply_gradients(zip(grads_optical, vars_optical))
+    train_op_digital = opt_digital.apply_gradients(zip(grads_digital, vars_digital))
+    train_op = tf.group(train_op_optical, train_op_digital)
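+
+# The single tf.gradients call is split so the mask coefficients ("PSFs" scope)
+# and the U-Net ("system" scope) get different learning rates; in step 1
+# lr_optical = 0, so only the U-Net weights actually move.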
+
+# tensorboard
+tf.summary.scalar('cost_train', cost_train)
+tf.summary.scalar('cost_valid', cost_valid)
+tf.summary.scalar('cost_rms_train', cost_rms_train)
+tf.summary.scalar('cost_rms_valid', cost_rms_valid)
+
+tf.summary.histogram('a_zernike', a_zernike)
+tf.summary.histogram('a_zernike_learn', a_zernike_learn)
+tf.summary.histogram('a_zernike_fix', a_zernike_fix)
+tf.summary.image('Height', tf.expand_dims(tf.expand_dims(h, 0), -1))
+tf.summary.image('sharp_valid', tf.image.convert_image_dtype(RGB_GT_valid[0:1, :, :, :], dtype = tf.uint8))
+tf.summary.image('blur_valid', tf.image.convert_image_dtype(blur_valid[0:1, :, :, :], dtype = tf.uint8))
+tf.summary.image('RGB_hat_valid', tf.image.convert_image_dtype(RGB_hat_valid[0:1, :, :, :], dtype = tf.uint8))
+tf.summary.image('PSF_n100', PSFs[0:1,:,:])    # Phi = -10
+tf.summary.image('PSF_p90', PSFs[19:20,:,:])   # Phi = +9
+tf.summary.image('PSF_p100', PSFs[20:21,:,:])  # Phi = +10
+merged = tf.summary.merge_all()
+
+##########################################   Train  #############################################
+
+# variables_to_restore = [v for v in tf.global_variables() if v.name.startswith('system')]
+# saver = tf.train.Saver(variables_to_restore)
+saver_all = tf.train.Saver(max_to_keep=1)
+saver_best = tf.train.Saver()
+
+with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    sess.run(tf.local_variables_initializer())
+
+    if not os.path.exists(results_dir):
+        os.makedirs(results_dir)
+
+    best_dir = 'best_model/'
+    if not os.path.exists(results_dir + best_dir):
+        os.makedirs(results_dir + best_dir)
+        best_valid_loss = 100
+    else:
+        best_valid_loss = np.loadtxt(results_dir + 'best_valid_loss.txt')
+        print('Current best valid loss = ' + str(best_valid_loss))
+
+    if not tf.train.checkpoint_exists(results_dir + 'checkpoint'):
+        # option1: run a new one
+        out_all = np.empty((0, 2))  # out_all columns: [train_loss, valid_loss]
+        print('Start to save at: ', results_dir)
+    else:
+        print(results_dir)
+        model_path = tf.train.latest_checkpoint(results_dir)
+        load_path = saver_all.restore(sess, model_path)
+        out_all = np.load(results_dir + 'out_all.npy')
+        print('Continue to save at: ', results_dir)
+
+    train_writer = tf.summary.FileWriter(results_dir + '/summary/', sess.graph)
+
+    # start queue-runner threads for parallel data loading
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+    for i in range(100000):
+        ## load the batch
+        train_op.run(feed_dict={lr_digital_ph: lr_digital})  # with lr_optical = 0, only the digital (U-Net) part is trained
+
+        if i != 0 and i % 10000 == 0:
+            lr_digital = lr_digital / 5  # divide the digital learning rate by 5 every 10k iterations (applied via the placeholder feed)
+
+        if i % 10 == 0:
+            [train_summary, loss_train, loss_valid, loss_rms_valid] = sess.run(
+                [merged, cost_train, cost_valid, cost_rms_valid])
+            train_writer.add_summary(train_summary, i)
+
+            print("Iter " + str(i) + ", Train Loss = " + \
+                  "{:.6f}".format(loss_train) + ", Valid Loss = " + \
+                  "{:.6f}".format(loss_valid))
+
+            # save them
+            out = np.array([[loss_train, loss_valid]])
+            out_all = np.vstack((out_all, out))
+            np.save(results_dir + 'out_all.npy', out_all)
+
+            saver_all.save(sess, results_dir + "model.ckpt", global_step=i)
+
+            [ht, at, PSFst] = sess.run([h, a_zernike, PSFs])
+            np.savetxt(results_dir + 'HeightMap.txt', ht)
+            np.savetxt(results_dir + 'a_zernike.txt', at)
+            np.save(results_dir + 'PSFs.npy', PSFst)
+
+            if (loss_rms_valid < best_valid_loss) and (i > 1):
+                best_valid_loss = loss_rms_valid
+                np.savetxt(results_dir + 'best_valid_loss.txt', [best_valid_loss])
+                saver_best.save(sess, results_dir + best_dir + "model.ckpt")
+                np.save(results_dir + best_dir + 'out_all.npy', out_all)
+                np.savetxt(results_dir + best_dir + 'HeightMap.txt', ht)
+                print('best at iter ' + str(i) + ' with loss = ' + str(best_valid_loss))
+
+    train_writer.close()
+    coord.request_stop()
+    coord.join(threads)