--- a +++ b/DeepDOF_step1.py @@ -0,0 +1,437 @@ +# End-to-end optimization for EDOF +# Author: Yicheng Wu @ Rice University +# 03/29/2019 +# 04/12/2019 parameter update +# 11/7/2019 update best model with valid_loss +# 12/3/2019 update best model with valid_rms instead of valid_loss +# 12/3/2019 update reblur cost = rms(blur, reblur) + +import tensorflow as tf +import scipy.io as sio +import numpy as np +import os +import Network + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" # only uses GPU 1 + +results_dir = "./results/" +DATA_PATH = "./DATA/" +TFRECORD_TRAIN_PATH = [DATA_PATH + 'npo_720um_train.tfrecords'] # for testing purpose both are validation sets +TFRECORD_VALID_PATH = [DATA_PATH + 'npo_720um_train.tfrecords'] + +## optimizer learning rates +# use 0 in step 1: +lr_optical = 0 +# use 1e-9 in step 2: +# lr_optical = 1e-9 +lr_digital = 1e-4 +print('lr_optical:' + str(lr_optical)) +print('lr_digital:' + str(lr_digital)) + + +########################################## Functions ############################################# + +# Peak SNR, could be used as cost function +def tf_PSNR(a, b, max_val, name=None): + with tf.name_scope(name, 'PSNR', [a, b]): + # Need to convert the images to float32. Scale max_val accordingly so that + # PSNR is computed correctly. + max_val = tf.cast(max_val, tf.float32) + a = tf.cast(a, tf.float32) + b = tf.cast(b, tf.float32) + mse = tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1]) + psnr_val = tf.subtract( + 20 * tf.log(max_val) / tf.log(10.0), + np.float32(10 / np.log(10)) * tf.log(mse), + name='psnr') + + return psnr_val + + +####### read from the TFRECORD format ################# +## for faster reading from Hard disk +def read_tfrecord(TFRECORD_PATH): + # from tfrecord file to data + N_w = 1000 # size of the images + N_h = 1000 + queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True) + reader = tf.TFRecordReader() + + _, serialized_example = reader.read(queue) + + features = tf.parse_single_example(serialized_example, + features={ + 'sharp': tf.FixedLenFeature([], tf.string), + }) + + RGB_flat = tf.decode_raw(features['sharp'], tf.uint8) + RGB = tf.reshape(RGB_flat, [N_h, N_w, 1]) + + return RGB + + + +########## Preprocess the images ############# +## crop to patches +## random flip +## Add uniform noise +############################################ +def data_augment(RGB_batch_float): + # crop to N_raw x N_raw + N_raw = 326 # for boundary effect, 256+70, will need cropping after convolution + data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float) + + # flip both images and labels + data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1) + + # only adjust the RGB value of the image + r1 = tf.random_uniform([]) * 0.3 + 0.8 + RGB_out = data2 * r1 + + return RGB_out + + + +############ Put data in batches ############# +## put in batch and shuffle +## cast to float32 +## call data_augment for image preprocess +## @param{TFRECORD_PATH}: path to the data +## @param{batchsize}: currently 21 for the 21 PSFs +############################################## +def read2batch(TFRECORD_PATH, batchsize): + # load tfrecord and make them to be usable data + RGB = read_tfrecord(TFRECORD_PATH) + RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200, + min_after_dequeue=50, num_threads=5) + RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32) + + RGB_batch_float = data_augment(RGB_batch_float) + + return RGB_batch_float[:,:,:,0:1] + + +def add_gaussian_noise(images, std): + noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32) + return tf.nn.relu(images + noise) + + + + +########### fftshift2D ################### +## the same as fftshift in MATLAB +## works for complex number +def fft2dshift(input): + dim = int(input.shape[1].value) # dimension of the data + channel1 = int(input.shape[0].value) # channels for the first dimension + if dim % 2 == 0: + # even version + # shift up and down + u = tf.slice(input, [0, 0, 0], [channel1, int((dim) / 2), dim]) + d = tf.slice(input, [0, int((dim) / 2), 0], [channel1, int((dim) / 2), dim]) + du = tf.concat([d, u], axis=1) + # shift left and right + l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim) / 2)]) + r = tf.slice(du, [0, 0, int((dim) / 2)], [channel1, dim, int((dim) / 2)]) + output = tf.concat([r, l], axis=2) + else: + # odd version + # shift up and down + u = tf.slice(input, [0, 0, 0], [channel1, int((dim + 1) / 2), dim]) + d = tf.slice(input, [0, int((dim + 1) / 2), 0], [channel1, int((dim - 1) / 2), dim]) + du = tf.concat([d, u], axis=1) + # shift left and right + l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim + 1) / 2)]) + r = tf.slice(du, [0, 0, int((dim + 1) / 2)], [channel1, dim, int((dim - 1) / 2)]) + output = tf.concat([r, l], axis=2) + return output + + + +######### generate out-of-focus phase ############### +## @param{Phi_list}: a list of Phi values +## @param{N_B}: size of the blur kernel +## @return{OOFphase} +def gen_OOFphase(Phi_list, N_B): + # return (Phi_list,pixel,pixel,color) + N = N_B + x0 = np.linspace(-2.84, 2.84, N) # 71/25 =2.84 + xx, yy = np.meshgrid(x0, x0) + OOFphase = np.empty([len(Phi_list), N, N, 1], dtype=np.float32) + for j in range(len(Phi_list)): + Phi = Phi_list[j] + OOFphase[j, :, :, 0] = Phi * (xx ** 2 + yy ** 2) + return OOFphase + + + +################## Generates the PSFs ######################## +## @param{h}: height map of the mask +## @param{OOFphase}: out-of-focus phase +## @param{wvls}: wavelength \lambda +## @param{idx}: index of the PSF +## @param{N_B}: size of the blur kernel +################################################################# +def gen_PSFs(h, OOFphase, wvls, idx, N_B): + n = 1.5 # diffractive index + + with tf.variable_scope("PSFs"): + OOFphase_B = OOFphase[:, :, :, 0] + phase_B = tf.add(2 * np.pi / wvls[0] * (n - 1) * h, OOFphase_B) # phase modulation of mask (phi_M) + Pupil_B = tf.multiply(tf.complex(idx, 0.0), tf.exp(tf.complex(0.0, phase_B)), name='Pupil_B') # pupil P + Norm_B = tf.cast(N_B * N_B * np.sum(idx ** 2), tf.float32) # what's this? + PSF_B = tf.divide(tf.square(tf.abs(fft2dshift(tf.fft2d(Pupil_B)))), Norm_B, name='PSF_B') + + return tf.expand_dims(PSF_B, -1) + + + +################ blur the images using PSFs ################## +## same patch different depths put in a stack +################################################################ +def one_wvl_blur(im, PSFs0): + N_B = PSFs0.shape[1].value + N_Phi = PSFs0.shape[0].value + N_im = im.shape[1].value + N_im_out = N_im - N_B + 1 # the final image size after blurring + + sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]), + [0, 2, 3, 1]) # reshape to make N_Phi in the last channel + PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1) + blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID') + blurStack = tf.transpose( + tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]), + perm=[0, 2, 3, 1]) # stack all N_Phi images to the first dimension + + return blurStack + + +def blurImage_diffPatch_diffDepth(RGB, PSFs): + blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0]) + + return blur + + +####################### system ########################## +## @param{PSFs}: the PSFs +## @param{RGB_batch_float}: patches +## @param{phase_BN}: batch normalization, True only during training +######################################################## +def system(PSFs, RGB_batch_float, phase_BN=True): + with tf.variable_scope("system", reuse=tf.AUTO_REUSE): + blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs) # size [batch_size * N_Phi, Nx, Ny, 3] + + # noise + sigma = 0.01 + blur_noisy = add_gaussian_noise(blur, sigma) + + RGB_hat = Network.UNet(blur_noisy, phase_BN) + + return blur, RGB_hat + + +###################### RMS cost ############################# +## @param{GT}: ground truth +## @param{hat}: reconstruction +############################################################## +def cost_rms(GT, hat): + cost = tf.sqrt(tf.reduce_mean(tf.square(GT - hat))) + return cost + + +########## compare the reconstruction reblured with U-net input? ############ +## important for EDOF to utilize the PSF information +## @param{RGB_hat}: Unet reconstructed image +## @param{PSFs}: PSF used +## @param{blur}: all-in-focus image conv PSF +## @param{N_B}: size of blur kernel +## @return{reblur}: reconstruction blurred +## @return{cost}: l2 norm between blur_GT and reblur +############################################################################## +def cost_reblur(RGB_hat, PSFs, blur, N_B): + reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs) + blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2), :] #crop the patch to 256x256 + + cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur))) + + return reblur, cost + + +######################################### Set parameters ############################################### + +# def main(): + +zernike = sio.loadmat('zernike_basis_150mm.mat') +u2 = zernike['u2'] # basis of zernike poly +idx = zernike['idx'] +idx = idx.astype(np.float32) + +a_zernike_mat = sio.loadmat('a_zernike_cubic_150mm.mat') +a_zernike_fix = a_zernike_mat['a'] +a_zernike_fix = a_zernike_fix * 4 +a_zernike_fix = tf.convert_to_tensor(a_zernike_fix) + +N_B = 71 # size of the blur kernel +wvls = np.array([550]) * 1e-9 # wavelength 550 nm +N_color = len(wvls) + +N_modes = u2.shape[1] # load zernike modes + +# generate the defocus phase +N_Phi = 21 +Phi_list = np.linspace(-10, 10, N_Phi, np.float32) # defocus +OOFphase = gen_OOFphase(Phi_list, N_B) # return (N_Phi,N_B,N_B,N_color) + +# baseline offset for the heightmap +c = 0 + +#################################### Build the architecture ##################################################### + + +with tf.variable_scope("PSFs"): + a_zernike_learn = tf.get_variable("a_zernike_learn", [N_modes, 1], initializer=tf.zeros_initializer(), + constraint=lambda x: tf.clip_by_value(x, -wvls[0] / 2, wvls[0] / 2)) + a_zernike = a_zernike_learn + a_zernike_fix # fixed cubic and learning part + g = tf.matmul(u2, a_zernike) + h = tf.nn.relu(tf.reshape(g, [N_B, N_B])+c, # c: baseline + name='heightMap') # height map of the phase mask, should be all positive + PSFs = gen_PSFs(h, OOFphase, wvls, idx, N_B) # return (N_Phi, N_B, N_B, N_color) + + +batch_size = N_Phi # it means that each patch is blurred at different depth. Will be an error if this is not N_Phi + + +RGB_batch_float = read2batch(TFRECORD_TRAIN_PATH, batch_size) +RGB_batch_float_valid = read2batch(TFRECORD_VALID_PATH, batch_size) + +[blur_train, RGB_hat_train] = system(PSFs, RGB_batch_float) +[blur_valid, RGB_hat_valid] = system(PSFs, RGB_batch_float_valid, phase_BN=False) + +# cost function +with tf.name_scope("cost"): + RGB_GT_train = RGB_batch_float[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), + int((N_B - 1) / 2):-int((N_B - 1) / 2), :] # crop the all-in-focus to be + RGB_GT_valid = RGB_batch_float_valid[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), + int((N_B - 1) / 2):-int((N_B - 1) / 2), :] + + cost_rms_train = cost_rms(RGB_GT_train, RGB_hat_train) + cost_rms_valid = cost_rms(RGB_GT_valid, RGB_hat_valid) + + cost_train = cost_rms_train + cost_valid = cost_rms_valid + +# train ditial and optical part saparetely +vars_optical = tf.trainable_variables("PSFs") +vars_digital = tf.trainable_variables("system") + +opt_optical = tf.train.AdamOptimizer(lr_optical) +opt_digital = tf.train.AdamOptimizer(lr_digital) + +global_step = tf.Variable(0, name='global_step', trainable=False) # initialize the stepsize + +update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # update the variables with gradient descent +with tf.control_dependencies(update_ops): + grads = tf.gradients(cost_train, vars_optical + vars_digital) + grads_optical = grads[:len(vars_optical)] + grads_digital = grads[len(vars_optical):] + train_op_optical = opt_optical.apply_gradients(zip(grads_optical, vars_optical)) + train_op_digital = opt_digital.apply_gradients(zip(grads_digital, vars_digital)) + train_op = tf.group(train_op_optical, train_op_digital) + +# tensorboard +tf.summary.scalar('cost_train', cost_train) +tf.summary.scalar('cost_valid', cost_valid) +tf.summary.scalar('cost_rms_train', cost_rms_train) +tf.summary.scalar('cost_rms_valid', cost_rms_valid) + +tf.summary.histogram('a_zernike', a_zernike) +tf.summary.histogram('a_zernike_learn', a_zernike_learn) +tf.summary.histogram('a_zernike_fix', a_zernike_fix) +tf.summary.image('Height', tf.expand_dims(tf.expand_dims(h, 0), -1)) +tf.summary.image('sharp_valid', tf.image.convert_image_dtype(RGB_GT_valid[0:1, :, :, :], dtype = tf.uint8)) +tf.summary.image('blur_valid', tf.image.convert_image_dtype(blur_valid[0:1, :, :, :], dtype = tf.uint8)) +tf.summary.image('RGB_hat_valid', tf.image.convert_image_dtype(RGB_hat_valid[0:1, :, :, :], dtype = tf.uint8)) +tf.summary.image('PSF_n100', PSFs[0:1,:,:]) +tf.summary.image('PSF_p90', PSFs[19:20,:,:]) +tf.summary.image('PSF_p100', PSFs[20:21,:,:]) +merged = tf.summary.merge_all() + +########################################## Train ############################################# + +# variables_to_restore = [v for v in tf.global_variables() if v.name.startswith('system')] +# saver = tf.train.Saver(variables_to_restore) +saver_all = tf.train.Saver(max_to_keep=1) +saver_best = tf.train.Saver() + +with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + + if not os.path.exists(results_dir): + os.makedirs(results_dir) + + best_dir = 'best_model/' + if not os.path.exists(results_dir + best_dir): + os.makedirs(results_dir + best_dir) + best_valid_loss = 100 + else: + best_valid_loss = np.loadtxt(results_dir + 'best_valid_loss.txt') + print('Current best valid loss = ' + str(best_valid_loss)) + + if not tf.train.checkpoint_exists(results_dir + 'checkpoint'): + # option1: run a new one + out_all = np.empty((0, 2)) # for out_all 4D: [train_loss,valid_loss,train_acc,valid_acc] + print('Start to save at: ', results_dir) + else: + print(results_dir) + model_path = tf.train.latest_checkpoint(results_dir) + load_path = saver_all.restore(sess, model_path) + out_all = np.load(results_dir + 'out_all.npy') + print('Continue to save at: ', results_dir) + + train_writer = tf.summary.FileWriter(results_dir + '/summary/', sess.graph) + + # threading for parallel + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(sess=sess, coord=coord) + + for i in range(100000): + ## load the batch + train_op.run() # only train digital + + if i != 0 and i % 10000 == 0: + lr_digital = lr_digital / 5 # reduce the learning rate every 10k + + if i % 10 == 0: + [train_summary, loss_train, loss_valid, loss_rms_valid] = sess.run( + [merged, cost_train, cost_valid, cost_rms_valid]) + train_writer.add_summary(train_summary, i) + + print("Iter " + str(i) + ", Train Loss = " + \ + "{:.6f}".format(loss_train) + ", Valid Loss = " + \ + "{:.6f}".format(loss_valid)) + + # save them + out = np.array([[loss_train, loss_valid]]) + out_all = np.vstack((out_all, out)) + np.save(results_dir + 'out_all.npy', out_all) + + saver_all.save(sess, results_dir + "model.ckpt", global_step=i) + + [ht, at, PSFst] = sess.run([h, a_zernike, PSFs]) + np.savetxt(results_dir + 'HeightMap.txt', ht) + np.savetxt(results_dir + 'a_zernike.txt', at) + np.save(results_dir + 'PSFs.npy', PSFst) + + if (loss_rms_valid < best_valid_loss) and (i > 1): + best_valid_loss = loss_rms_valid + np.savetxt(results_dir + 'best_valid_loss.txt', [best_valid_loss]) + saver_best.save(sess, results_dir + best_dir + "model.ckpt") + np.save(results_dir + best_dir + 'out_all.npy', out_all) + np.savetxt(results_dir + best_dir + 'HeightMap.txt', ht) + print('best at iter ' + str(i) + ' with loss = ' + str(best_valid_loss)) + + train_writer.close() + coord.request_stop() + coord.join(threads)