|
DeepDOF_step2.py
|
|
# End-to-end optimization for EDOF
# Author: Yicheng Wu @ Rice University
# 03/29/2019
# 04/12/2019 parameter update
# 11/07/2019 update best model with valid_loss
# 12/03/2019 update best model with valid_rms instead of valid_loss
# 12/03/2019 update reblur cost = rms(blur, reblur)
|
|
|
import tensorflow as tf
import scipy.io as sio
import numpy as np
import os
import Network  # local module that defines the U-Net (Network.UNet)
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # only use GPU 0
|
|
|
results_dir = "./results/"
DATA_PATH = './DATA/'
# NOTE: for testing purposes, train and valid currently point to the same
# tfrecord file; use separate files for a real training/validation split.
TFRECORD_TRAIN_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']
TFRECORD_VALID_PATH = [DATA_PATH + 'npo_720um_train.tfrecords']
|
|
|
## optimizer learning rates
# step 1 (train the digital network only): lr_optical = 0
# step 2 (joint optical + digital):        lr_optical = 1e-9
lr_optical = 1e-9
lr_digital = 1e-4
print('lr_optical: ' + str(lr_optical))
print('lr_digital: ' + str(lr_digital))


########################################## Functions #############################################
|
|
|
# Peak SNR; could be used as a cost function
def tf_PSNR(a, b, max_val, name=None):
    with tf.name_scope(name, 'PSNR', [a, b]):
        # Convert the images to float32; scale max_val accordingly so that
        # PSNR is computed correctly.
        max_val = tf.cast(max_val, tf.float32)
        a = tf.cast(a, tf.float32)
        b = tf.cast(b, tf.float32)
        mse = tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
        psnr_val = tf.subtract(
            20 * tf.log(max_val) / tf.log(10.0),
            np.float32(10 / np.log(10)) * tf.log(mse),
            name='psnr')

    return psnr_val
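
# The two-term form above is PSNR = 10*log10(max_val^2 / MSE) split as
# 20*log10(max_val) - 10*log10(MSE). A NumPy reference version for sanity
# checks outside the graph (a minimal sketch, not used by the pipeline):
def np_PSNR(a, b, max_val):
    mse = np.mean((np.float32(a) - np.float32(b)) ** 2)
    return 20 * np.log10(max_val) - 10 * np.log10(mse)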
|
|
####### read from the TFRECORD format #################
## for faster reading from the hard disk
def read_tfrecord(TFRECORD_PATH):
    # decode tfrecord files into image tensors
    N_w = 1000  # width of the stored images
    N_h = 1000  # height of the stored images
    queue = tf.train.string_input_producer(TFRECORD_PATH, shuffle=True)
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(queue)

    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'sharp': tf.FixedLenFeature([], tf.string),
                                       })

    RGB_flat = tf.decode_raw(features['sharp'], tf.uint8)
    RGB = tf.reshape(RGB_flat, [N_h, N_w, 1])

    return RGB
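
# Writer-side counterpart, for reference: each record is assumed to hold one
# 1000x1000 grayscale uint8 image stored as raw bytes under the key 'sharp'.
# A minimal sketch (the image source is hypothetical, not part of this repo):
def write_tfrecord(images, out_path):
    writer = tf.python_io.TFRecordWriter(out_path)
    for img in images:  # img: (1000, 1000) uint8 numpy array
        feature = {'sharp': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[img.tostring()]))}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    writer.close()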
|
|
########## Preprocess the images #############
## crop to patches
## random flip
## random brightness scaling
##############################################
def data_augment(RGB_batch_float):
    # crop to N_raw x N_raw
    N_raw = 326  # 256 + 70 to absorb the boundary effect; cropped again after convolution
    data1 = tf.map_fn(lambda img: tf.random_crop(img, [N_raw, N_raw, 1]), RGB_batch_float)

    # random horizontal and vertical flips
    data2 = tf.map_fn(lambda img: tf.image.random_flip_up_down(tf.image.random_flip_left_right(img)), data1)

    # random brightness: scale intensities by a factor drawn from [0.8, 1.1)
    r1 = tf.random_uniform([]) * 0.3 + 0.8
    RGB_out = data2 * r1

    return RGB_out
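
# Patch-size bookkeeping: a 326 x 326 crop convolved with the 71 x 71 PSF in
# 'VALID' mode yields 326 - 71 + 1 = 256 x 256 patches downstream.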
|
|
############ Put data in batches #############
## batch and shuffle
## cast to float32
## call data_augment for preprocessing
## @param{TFRECORD_PATH}: path to the data
## @param{batchsize}: currently 21, one patch per PSF / depth
##############################################
def read2batch(TFRECORD_PATH, batchsize):
    # load the tfrecord and turn it into usable batches
    RGB = read_tfrecord(TFRECORD_PATH)
    RGB_batch = tf.train.shuffle_batch([RGB], batch_size=batchsize, capacity=200,
                                       min_after_dequeue=50, num_threads=5)
    RGB_batch_float = tf.image.convert_image_dtype(RGB_batch, tf.float32)

    RGB_batch_float = data_augment(RGB_batch_float)

    return RGB_batch_float[:, :, :, 0:1]
|
|
def add_gaussian_noise(images, std):
    noise = tf.random_normal(shape=tf.shape(images), mean=0.0, stddev=std, dtype=tf.float32)
    return tf.nn.relu(images + noise)  # relu clips negative values introduced by the noise
|
|
########### fftshift2D ###################
## same as fftshift in MATLAB
## works for complex numbers
## input shape: [channel, dim, dim]; shifts the last two axes
def fft2dshift(input):
    dim = int(input.shape[1].value)  # spatial dimension of the data
    channel1 = int(input.shape[0].value)  # number of channels (first dimension)
    if dim % 2 == 0:
        # even size
        # shift up and down
        u = tf.slice(input, [0, 0, 0], [channel1, int(dim / 2), dim])
        d = tf.slice(input, [0, int(dim / 2), 0], [channel1, int(dim / 2), dim])
        du = tf.concat([d, u], axis=1)
        # shift left and right
        l = tf.slice(du, [0, 0, 0], [channel1, dim, int(dim / 2)])
        r = tf.slice(du, [0, 0, int(dim / 2)], [channel1, dim, int(dim / 2)])
        output = tf.concat([r, l], axis=2)
    else:
        # odd size
        # shift up and down
        u = tf.slice(input, [0, 0, 0], [channel1, int((dim + 1) / 2), dim])
        d = tf.slice(input, [0, int((dim + 1) / 2), 0], [channel1, int((dim - 1) / 2), dim])
        du = tf.concat([d, u], axis=1)
        # shift left and right
        l = tf.slice(du, [0, 0, 0], [channel1, dim, int((dim + 1) / 2)])
        r = tf.slice(du, [0, 0, int((dim + 1) / 2)], [channel1, dim, int((dim - 1) / 2)])
        output = tf.concat([r, l], axis=2)
    return output
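
# For a [C, N, N] array this matches NumPy's fftshift on the last two axes
# (both even and odd N), which can serve as a reference in tests -- a minimal
# sketch, not used by the pipeline:
def np_fft2dshift_reference(x):
    return np.fft.fftshift(x, axes=(1, 2))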
|
|
######### generate the out-of-focus (defocus) phase ###############
## @param{Phi_list}: list of defocus parameters Phi
## @param{N_B}: size of the blur kernel
## @return{OOFphase}: array of shape (len(Phi_list), N_B, N_B, 1)
def gen_OOFphase(Phi_list, N_B):
    N = N_B
    # normalized pupil coordinates; 71/25 = 2.84, so the unit circle covers
    # the central 25-pixel radius of the 71-pixel kernel
    x0 = np.linspace(-2.84, 2.84, N)
    xx, yy = np.meshgrid(x0, x0)
    OOFphase = np.empty([len(Phi_list), N, N, 1], dtype=np.float32)
    for j in range(len(Phi_list)):
        Phi = Phi_list[j]
        OOFphase[j, :, :, 0] = Phi * (xx ** 2 + yy ** 2)  # quadratic defocus phase
    return OOFphase
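
# Reading the formula: Phi = 0 is the in-focus plane, and |Phi| grows with
# distance from it; Phi is the defocus phase in radians at the normalized
# pupil edge, since the quadratic term equals Phi where x^2 + y^2 = 1.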
|
|
################## Generate the PSFs ########################
## @param{h}: height map of the mask
## @param{OOFphase}: out-of-focus phase
## @param{wvls}: wavelength lambda
## @param{idx}: aperture amplitude (pupil support)
## @param{N_B}: size of the blur kernel
#################################################################
def gen_PSFs(h, OOFphase, wvls, idx, N_B):
    n = 1.5  # refractive index of the mask material

    with tf.variable_scope("PSFs"):
        OOFphase_B = OOFphase[:, :, :, 0]
        phase_B = tf.add(2 * np.pi / wvls[0] * (n - 1) * h, OOFphase_B)  # phase modulation of mask (phi_M)
        Pupil_B = tf.multiply(tf.complex(idx, 0.0), tf.exp(tf.complex(0.0, phase_B)), name='Pupil_B')  # pupil P
        # Parseval normalization: sum |FFT2(P)|^2 = N_B^2 * sum |P|^2
        # = N_B^2 * sum(idx^2), so each PSF integrates to 1
        Norm_B = tf.cast(N_B * N_B * np.sum(idx ** 2), tf.float32)
        PSF_B = tf.divide(tf.square(tf.abs(fft2dshift(tf.fft2d(Pupil_B)))), Norm_B, name='PSF_B')

    return tf.expand_dims(PSF_B, -1)
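
# Fourier-optics pipeline: height map h -> phase delay (2*pi/lambda)(n-1)h ->
# pupil function P = idx * exp(i*phase) -> PSF = |fftshift(FFT2(P))|^2 / Norm.
# Because every step is differentiable, gradients w.r.t. h let the optimizer
# shape the mask itself.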
|
|
################ blur the images using the PSFs ##################
## the same patch at different depths is stacked along the batch axis
##################################################################
def one_wvl_blur(im, PSFs0):
    N_B = PSFs0.shape[1].value    # PSF size
    N_Phi = PSFs0.shape[0].value  # number of depths
    N_im = im.shape[1].value      # input patch size
    N_im_out = N_im - N_B + 1     # image size after 'VALID' convolution

    # reshape so that N_Phi becomes the channel dimension
    sharp = tf.transpose(tf.reshape(im, [-1, N_Phi, N_im, N_im]), [0, 2, 3, 1])
    PSFs = tf.expand_dims(tf.transpose(PSFs0, perm=[1, 2, 0]), -1)
    # depthwise conv: channel k (depth k) is convolved with PSF k only
    blurAll = tf.nn.depthwise_conv2d(sharp, PSFs, strides=[1, 1, 1, 1], padding='VALID')
    # unstack the N_Phi channels back onto the batch dimension
    blurStack = tf.transpose(
        tf.reshape(tf.transpose(blurAll, perm=[0, 3, 1, 2]), [-1, 1, N_im_out, N_im_out]),
        perm=[0, 2, 3, 1])

    return blurStack
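
# Shape walkthrough for batch_size = N_Phi = 21, N_im = 326, N_B = 71:
# im [21, 326, 326] -> sharp [1, 326, 326, 21] -> depthwise conv with
# PSFs [71, 71, 21, 1] -> blurAll [1, 256, 256, 21] -> blurStack [21, 256, 256, 1],
# i.e. patch k ends up blurred with the PSF of depth k.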
|
|
def blurImage_diffPatch_diffDepth(RGB, PSFs):
    # single wavelength: blur the one image channel with the per-depth PSFs
    blur = one_wvl_blur(RGB[:, :, :, 0], PSFs[:, :, :, 0])

    return blur
|
|
####################### system ##########################
## @param{PSFs}: the PSFs
## @param{RGB_batch_float}: input patches
## @param{phase_BN}: batch-normalization phase; True only during training
#########################################################
def system(PSFs, RGB_batch_float, phase_BN=True):
    with tf.variable_scope("system", reuse=tf.AUTO_REUSE):
        blur = blurImage_diffPatch_diffDepth(RGB_batch_float, PSFs)  # size [batch_size, Nx, Ny, 1]

        # sensor noise
        sigma = 0.01
        blur_noisy = add_gaussian_noise(blur, sigma)

        RGB_hat = Network.UNet(blur_noisy, phase_BN)

    return blur, RGB_hat
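
# Forward model implemented by system(): y = (x * PSF_Phi) + n, followed by
# the U-Net reconstruction x_hat = UNet(y). Gradients flow through both the
# network weights and, via the PSFs, the mask height map, which is what makes
# the optics learnable end to end.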
|
|
###################### RMS cost #############################
## @param{GT}: ground truth
## @param{hat}: reconstruction
#############################################################
def cost_rms(GT, hat):
    cost = tf.sqrt(tf.reduce_mean(tf.square(GT - hat)))
    return cost
|
|
########## reblur the reconstruction and compare it with the U-net input ##########
## encourages the reconstruction to stay consistent with the PSFs (important for EDOF)
## @param{RGB_hat}: U-Net reconstructed image
## @param{PSFs}: PSFs used
## @param{blur}: all-in-focus image convolved with the PSFs
## @param{N_B}: size of the blur kernel
## @return{reblur}: reconstruction blurred again
## @return{cost}: RMS error between blur_GT and reblur
####################################################################################
def cost_reblur(RGB_hat, PSFs, blur, N_B):
    reblur = blurImage_diffPatch_diffDepth(RGB_hat, PSFs)
    # crop the blurred input to match the reblurred size (256 x 256)
    blur_GT = blur[:, int((N_B - 1) / 2):-int((N_B - 1) / 2), int((N_B - 1) / 2):-int((N_B - 1) / 2), :]

    cost = tf.sqrt(tf.reduce_mean(tf.square(blur_GT - reblur)))

    return reblur, cost
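
# Note: cost_reblur is defined here but never added to cost_train below;
# in this script the training cost is the reconstruction RMS alone.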
|
|
######################################### Set parameters ###############################################

# def main():
|
|
zernike = sio.loadmat('zernike_basis_150mm.mat')
u2 = zernike['u2']  # Zernike polynomial basis
idx = zernike['idx']  # aperture support
idx = idx.astype(np.float32)

a_zernike_mat = sio.loadmat('a_zernike_cubic_150mm.mat')
a_zernike_fix = a_zernike_mat['a']  # fixed cubic-phase coefficients
a_zernike_fix = a_zernike_fix * 4
a_zernike_fix = tf.convert_to_tensor(a_zernike_fix)
|
|
N_B = 71  # size of the blur kernel
wvls = np.array([550]) * 1e-9  # wavelength: 550 nm
N_color = len(wvls)

N_modes = u2.shape[1]  # number of Zernike modes

# generate the defocus phase
N_Phi = 21
Phi_list = np.linspace(-10, 10, N_Phi, dtype=np.float32)  # defocus parameters
OOFphase = gen_OOFphase(Phi_list, N_B)  # shape (N_Phi, N_B, N_B, N_color)

# baseline offset for the height map
c = 0
|
|
#################################### Build the architecture #####################################################


with tf.variable_scope("PSFs"):
    # learnable residual on top of the fixed cubic phase, clipped to +/- lambda/2
    a_zernike_learn = tf.get_variable("a_zernike_learn", [N_modes, 1], initializer=tf.zeros_initializer(),
                                      constraint=lambda x: tf.clip_by_value(x, -wvls[0] / 2, wvls[0] / 2))
    a_zernike = a_zernike_learn + a_zernike_fix  # fixed cubic part + learnable part
    g = tf.matmul(u2, a_zernike)
    h = tf.nn.relu(tf.reshape(g, [N_B, N_B]) + c,  # c: baseline offset
                   name='heightMap')  # height map of the phase mask; must be non-negative
    PSFs = gen_PSFs(h, OOFphase, wvls, idx, N_B)  # shape (N_Phi, N_B, N_B, N_color)
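
# The mask is parameterized as Zernike coefficients rather than raw pixels:
# h = relu(u2 @ (a_fix + a_learn) + c), which keeps the height map smooth and
# lets the clip constraint bound the learnable part to half a wavelength.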
|
|
# each patch in the batch is blurred at a different depth;
# the code requires batch_size == N_Phi
batch_size = N_Phi

RGB_batch_float = read2batch(TFRECORD_TRAIN_PATH, batch_size)
RGB_batch_float_valid = read2batch(TFRECORD_VALID_PATH, batch_size)
|
|
[blur_train, RGB_hat_train] = system(PSFs, RGB_batch_float)
[blur_valid, RGB_hat_valid] = system(PSFs, RGB_batch_float_valid, phase_BN=False)
|
|
# cost function
with tf.name_scope("cost"):
    # crop the all-in-focus images to match the 'VALID'-convolved output size
    RGB_GT_train = RGB_batch_float[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
                                   int((N_B - 1) / 2):-int((N_B - 1) / 2), :]
    RGB_GT_valid = RGB_batch_float_valid[:, int((N_B - 1) / 2):-int((N_B - 1) / 2),
                                         int((N_B - 1) / 2):-int((N_B - 1) / 2), :]

    cost_rms_train = cost_rms(RGB_GT_train, RGB_hat_train)
    cost_rms_valid = cost_rms(RGB_GT_valid, RGB_hat_valid)
    cost_train = cost_rms_train
    cost_valid = cost_rms_valid
|
|
# train the digital and optical parts separately
vars_optical = tf.trainable_variables("PSFs")
vars_digital = tf.trainable_variables("system")

opt_optical = tf.train.AdamOptimizer(lr_optical)
opt_digital = tf.train.AdamOptimizer(lr_digital)

global_step = tf.Variable(0, name='global_step', trainable=False)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # ensure batch-norm statistics are updated
with tf.control_dependencies(update_ops):
    grads = tf.gradients(cost_train, vars_optical + vars_digital)
    grads_optical = grads[:len(vars_optical)]
    grads_digital = grads[len(vars_optical):]
    train_op_optical = opt_optical.apply_gradients(zip(grads_optical, vars_optical))
    train_op_digital = opt_digital.apply_gradients(zip(grads_digital, vars_digital))
    train_op = tf.group(train_op_optical, train_op_digital)
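
# An equivalent, more compact alternative (a sketch) is to give each optimizer
# its var_list directly:
#   train_op = tf.group(
#       opt_optical.minimize(cost_train, var_list=vars_optical),
#       opt_digital.minimize(cost_train, var_list=vars_digital))
# The explicit tf.gradients form above evaluates the gradients only once for
# both variable groups.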
|
|
# tensorboard summaries
tf.summary.scalar('cost_train', cost_train)
tf.summary.scalar('cost_valid', cost_valid)
tf.summary.scalar('cost_rms_train', cost_rms_train)
tf.summary.scalar('cost_rms_valid', cost_rms_valid)

tf.summary.histogram('a_zernike', a_zernike)
tf.summary.histogram('a_zernike_learn', a_zernike_learn)
tf.summary.histogram('a_zernike_fix', a_zernike_fix)
tf.summary.image('Height', tf.expand_dims(tf.expand_dims(h, 0), -1))
tf.summary.image('sharp_valid', tf.image.convert_image_dtype(RGB_GT_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('blur_valid', tf.image.convert_image_dtype(blur_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('RGB_hat_valid', tf.image.convert_image_dtype(RGB_hat_valid[0:1, :, :, :], dtype=tf.uint8))
tf.summary.image('PSF_n100', PSFs[0:1, :, :])    # Phi = -10
tf.summary.image('PSF_p90', PSFs[19:20, :, :])   # Phi = +9
tf.summary.image('PSF_p100', PSFs[20:21, :, :])  # Phi = +10
merged = tf.summary.merge_all()
|
|
########################################## Train #############################################

# variables_to_restore = [v for v in tf.global_variables() if v.name.startswith('system')]
# saver = tf.train.Saver(variables_to_restore)
saver_all = tf.train.Saver(max_to_keep=1)
saver_best = tf.train.Saver()
|
|
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    best_dir = 'best_model/'
    if not os.path.exists(results_dir + best_dir):
        os.makedirs(results_dir + best_dir)
        best_valid_loss = 100  # large initial value; any real validation loss beats it
    else:
        best_valid_loss = np.loadtxt(results_dir + 'best_valid_loss.txt')
        print('Current best valid loss = ' + str(best_valid_loss))

    if not tf.train.checkpoint_exists(results_dir + 'checkpoint'):
        # option 1: start a new run
        out_all = np.empty((0, 2))  # columns: [train_loss, valid_loss]
        print('Start to save at: ', results_dir)
    else:
        # option 2: resume from the latest checkpoint
        print(results_dir)
        model_path = tf.train.latest_checkpoint(results_dir)
        load_path = saver_all.restore(sess, model_path)
        out_all = np.load(results_dir + 'out_all.npy')
        print('Continue to save at: ', results_dir)

    train_writer = tf.summary.FileWriter(results_dir + '/summary/', sess.graph)

    # queue-runner threads for parallel data loading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(100000):
        # one optimization step; in step 2 this updates both the optical and
        # the digital variables (the batch is fed by the queue runners)
        train_op.run()

        if i % 10 == 0:
            [train_summary, loss_train, loss_valid, loss_rms_valid] = sess.run(
                [merged, cost_train, cost_valid, cost_rms_valid])
            train_writer.add_summary(train_summary, i)

            print("Iter " + str(i) + ", Train Loss = " +
                  "{:.6f}".format(loss_train) + ", Valid Loss = " +
                  "{:.6f}".format(loss_valid))

            # log the losses
            out = np.array([[loss_train, loss_valid]])
            out_all = np.vstack((out_all, out))
            np.save(results_dir + 'out_all.npy', out_all)

            saver_all.save(sess, results_dir + "model.ckpt", global_step=i)

            [ht, at, PSFst] = sess.run([h, a_zernike, PSFs])
            np.savetxt(results_dir + 'HeightMap.txt', ht)
            np.savetxt(results_dir + 'a_zernike.txt', at)
            np.save(results_dir + 'PSFs.npy', PSFst)

            # keep the best model as measured by validation RMS
            if (loss_rms_valid < best_valid_loss) and (i > 1):
                best_valid_loss = loss_rms_valid
                np.savetxt(results_dir + 'best_valid_loss.txt', [best_valid_loss])
                saver_best.save(sess, results_dir + best_dir + "model.ckpt")
                np.save(results_dir + best_dir + 'out_all.npy', out_all)
                np.savetxt(results_dir + best_dir + 'HeightMap.txt', ht)
                print('best at iter ' + str(i) + ' with loss = ' + str(best_valid_loss))

    train_writer.close()
    coord.request_stop()
    coord.join(threads)