|
a |
|
b/Medical-Image-Segmentation_DCGAN.py |
|
|
1 |
|
|
|
2 |
# coding: utf-8 |
|
|
3 |
|
|
|
4 |
# In[1]: |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
import os |
|
|
8 |
from medpy.io import load |
|
|
9 |
import numpy as np |
|
|
10 |
import cv2 as cv |
|
|
11 |
from PIL import Image |
|
|
12 |
|
|
|
13 |
PATH = os.path.abspath("E:/UB CSE/Spring 2018/700/Project/BRATS2013/BRATS_Training/BRATS-2/Image_Data") |
|
|
14 |
|
|
|
15 |
# pad image to standardize size, then crop down as required (by memory constraints) |
|
|
16 |
def pad_image(img, desired_shape=(256, 256)): |
|
|
17 |
pad_top = 0 |
|
|
18 |
pad_bot = 0 |
|
|
19 |
pad_left = 0 |
|
|
20 |
pad_right = 0 |
|
|
21 |
if desired_shape[0] > img.shape[0]: |
|
|
22 |
pad_top = int((desired_shape[0] - img.shape[0]) / 2) |
|
|
23 |
pad_bot = desired_shape[0] - img.shape[0] - pad_top |
|
|
24 |
if desired_shape[1] > img.shape[1]: |
|
|
25 |
pad_left = int((desired_shape[1] - img.shape[1]) / 2) |
|
|
26 |
pad_right = desired_shape[1] - img.shape[1] - pad_left |
|
|
27 |
img = np.pad(img, ((pad_top, pad_bot), (pad_left, pad_right)), 'constant') |
|
|
28 |
|
|
|
29 |
img = img[50:200,50:200] |
|
|
30 |
img = cv.resize(img, dsize=(28,28), interpolation=cv.INTER_CUBIC) |
|
|
31 |
|
|
|
32 |
return img |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def normalize(img): |
|
|
36 |
nimg = None |
|
|
37 |
nimg = cv.normalize(img.astype('float'), nimg, alpha=0.0, beta=1.0, norm_type=cv.NORM_MINMAX) |
|
|
38 |
nimg = pad_image(nimg, desired_shape=(256, 256)) |
|
|
39 |
nimg.round(decimals=2) |
|
|
40 |
return nimg |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def load_single_image(path): |
|
|
44 |
for dir, subdir, files in os.walk(path): |
|
|
45 |
for file in files: |
|
|
46 |
if file.endswith(".mha"): |
|
|
47 |
img = load_itk(os.path.join(path, file)) |
|
|
48 |
return img |
|
|
49 |
|
|
|
50 |
|
|
|
51 |
def create_1_chan_data(flair, ot): |
|
|
52 |
ot_layers = [] |
|
|
53 |
flair_layers = [] |
|
|
54 |
# print("OT shape",ot.shape[2]) |
|
|
55 |
for layer in range(ot.shape[2]): |
|
|
56 |
ot_layers.append(pad_image(ot[:, :, layer], desired_shape=(256, 256))) |
|
|
57 |
# print("Flair intensities: ", np.unique(flair[:, :, layer])) |
|
|
58 |
normalizedImage = normalize(flair[:, :, layer]) |
|
|
59 |
# print("Normalized Image intensities: ", np.unique(normalizedImage)) |
|
|
60 |
flair_layers.append(normalizedImage) |
|
|
61 |
|
|
|
62 |
return np.stack(ot_layers, axis=0), np.stack(flair_layers, axis=0) |
|
|
63 |
|
|
|
64 |
# BRaTS dataset contains 4 channels of input data and one channel of groundtruth for a 3D brain scan image. |
|
|
65 |
def load_dataset(path): |
|
|
66 |
|
|
|
67 |
train_flair = [] |
|
|
68 |
train_ot = [] |
|
|
69 |
|
|
|
70 |
for dir in os.listdir(path): |
|
|
71 |
if dir == 'HG': |
|
|
72 |
HG_path = os.path.join(path, 'HG') |
|
|
73 |
for dir2 in os.listdir(HG_path): |
|
|
74 |
if dir2 != '.DS_Store': |
|
|
75 |
HG_flair = load_single_image(os.path.join(HG_path, dir2, 'VSD.Brain.XX.O.MR_Flair')) |
|
|
76 |
HG_ot = load_single_image(os.path.join(HG_path, dir2, 'VSD.Brain_3more.XX.XX.OT')) |
|
|
77 |
assert (HG_ot.shape == HG_flair.shape ) |
|
|
78 |
HG_samples = create_1_chan_data(HG_flair, HG_ot) |
|
|
79 |
train_ot.append(HG_samples[0]) |
|
|
80 |
train_flair.append(HG_samples[1]) |
|
|
81 |
|
|
|
82 |
if dir == 'LG': |
|
|
83 |
brain_1 = brain_2 = brain_3 = False |
|
|
84 |
LG_path = os.path.join(path, 'LG') |
|
|
85 |
for dir3 in os.listdir(LG_path): |
|
|
86 |
if dir3 != '.DS_Store': |
|
|
87 |
LG_flair = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain.XX.O.MR_Flair')) |
|
|
88 |
brain_1 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_1more.XX.XX.OT')) |
|
|
89 |
brain_2 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_2more.XX.XX.OT')) |
|
|
90 |
brain_3 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_3more.XX.XX.OT')) |
|
|
91 |
if brain_1: |
|
|
92 |
LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_1more.XX.XX.OT')) |
|
|
93 |
if brain_2: |
|
|
94 |
LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_2more.XX.XX.OT')) |
|
|
95 |
if brain_3: |
|
|
96 |
LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_3more.XX.XX.OT')) |
|
|
97 |
|
|
|
98 |
assert (LG_ot.shape == LG_flair.shape) |
|
|
99 |
LG_samples = create_1_chan_data(LG_flair, LG_ot) |
|
|
100 |
train_ot.append(LG_samples[0]) |
|
|
101 |
train_flair.append(LG_samples[1]) |
|
|
102 |
# Stacking all individual layers |
|
|
103 |
train_ot = np.vstack(train_ot) |
|
|
104 |
train_flair = np.vstack(train_flair) |
|
|
105 |
assert (train_ot.shape == train_flair.shape) |
|
|
106 |
return train_flair,train_ot |
|
|
107 |
|
|
|
108 |
|
|
|
109 |
# In[2]: |
|
|
110 |
|
|
|
111 |
#SimpleITK is used for reading the brain scan images |
|
|
112 |
import SimpleITK as sitk |
|
|
113 |
import numpy as np |
|
|
114 |
import os |
|
|
115 |
import glob |
|
|
116 |
from medpy.io import load |
|
|
117 |
''' |
|
|
118 |
This funciton reads a '.mhd' file using SimpleITK and return the image array, origin and spacing of the image. |
|
|
119 |
''' |
|
|
120 |
|
|
|
121 |
def load_itk(filename): |
|
|
122 |
# Reads the image using SimpleITK |
|
|
123 |
itkimage = sitk.ReadImage(filename) |
|
|
124 |
|
|
|
125 |
# Convert the image to a numpy array first and then shuffle the dimensions to get axis in the order z,y,x |
|
|
126 |
ct_scan = sitk.GetArrayFromImage(itkimage) |
|
|
127 |
|
|
|
128 |
# Read the origin of the ct_scan, will be used to convert the coordinates from world to voxel and vice versa. |
|
|
129 |
origin = np.array(list(reversed(itkimage.GetOrigin()))) |
|
|
130 |
|
|
|
131 |
# Read the spacing along each dimension |
|
|
132 |
spacing = np.array(list(reversed(itkimage.GetSpacing()))) |
|
|
133 |
|
|
|
134 |
# return ct_scan, origin, spacing |
|
|
135 |
return ct_scan |
|
|
136 |
|
|
|
137 |
|
|
|
138 |
# In[3]: |
|
|
139 |
|
|
|
140 |
|
|
|
141 |
flair_data, ot_data =load_dataset(PATH) |
|
|
142 |
|
|
|
143 |
|
|
|
144 |
# In[4]: |
|
|
145 |
|
|
|
146 |
|
|
|
147 |
print(flair_data.shape) |
|
|
148 |
|
|
|
149 |
|
|
|
150 |
# In[5]: |
|
|
151 |
|
|
|
152 |
|
|
|
153 |
import matplotlib.pyplot as plt |
|
|
154 |
# fig1 = plt.figure() |
|
|
155 |
plt.imshow(ot_data[420,:,:]) |
|
|
156 |
plt.savefig('sample.png') |
|
|
157 |
plt.show() |
|
|
158 |
|
|
|
159 |
|
|
|
160 |
# In[6]: |
|
|
161 |
|
|
|
162 |
|
|
|
163 |
print(np.unique(ot_data[420,:,:])) |
|
|
164 |
|
|
|
165 |
|
|
|
166 |
# In[7]: |
|
|
167 |
|
|
|
168 |
|
|
|
169 |
# imginput = x[0] |
|
|
170 |
# imgoutput = x[1] |
|
|
171 |
|
|
|
172 |
|
|
|
173 |
# In[8]: |
|
|
174 |
|
|
|
175 |
|
|
|
176 |
print(flair_data.shape) |
|
|
177 |
|
|
|
178 |
|
|
|
179 |
# In[9]: |
|
|
180 |
|
|
|
181 |
|
|
|
182 |
print(ot_data.shape) |
|
|
183 |
|
|
|
184 |
|
|
|
185 |
# In[10]: |
|
|
186 |
|
|
|
187 |
|
|
|
188 |
np.amax(ot_data) |
|
|
189 |
|
|
|
190 |
|
|
|
191 |
# # Experiment |
|
|
192 |
|
|
|
193 |
# In[11]: |
|
|
194 |
|
|
|
195 |
|
|
|
196 |
import os |
|
|
197 |
from glob import glob |
|
|
198 |
from matplotlib import pyplot |
|
|
199 |
from PIL import Image |
|
|
200 |
import numpy as np |
|
|
201 |
|
|
|
202 |
|
|
|
203 |
# Image configuration |
|
|
204 |
IMAGE_HEIGHT = 28 |
|
|
205 |
IMAGE_WIDTH = 28 |
|
|
206 |
data_files = PATH |
|
|
207 |
# shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT,1 |
|
|
208 |
shape = flair_data.shape[0],flair_data.shape[1],flair_data.shape[2],1 |
|
|
209 |
print(shape) |
|
|
210 |
|
|
|
211 |
|
|
|
212 |
# In[12]: |
|
|
213 |
|
|
|
214 |
|
|
|
215 |
|
|
|
216 |
def get_batches(batch_size): |
|
|
217 |
""" |
|
|
218 |
Generate batches |
|
|
219 |
""" |
|
|
220 |
# IMAGE_MAX_VALUE = 255 |
|
|
221 |
|
|
|
222 |
|
|
|
223 |
current_index = 0 |
|
|
224 |
while current_index + batch_size <= shape[0]: |
|
|
225 |
|
|
|
226 |
data_batch = (ot_data[current_index:current_index + batch_size]) |
|
|
227 |
z_batch = (flair_data[current_index:current_index + batch_size]) |
|
|
228 |
#print(type(data_batch)) |
|
|
229 |
#print(data_batch.shape) |
|
|
230 |
data_batch = data_batch[...,np.newaxis] |
|
|
231 |
#print(data_batch.shape) |
|
|
232 |
|
|
|
233 |
|
|
|
234 |
# np.vstack((data_batch, x[1,current_index:current_index + batch_size])) |
|
|
235 |
|
|
|
236 |
|
|
|
237 |
|
|
|
238 |
current_index += batch_size |
|
|
239 |
|
|
|
240 |
# return data_batch / IMAGE_MAX_VALUE - 0.5 |
|
|
241 |
|
|
|
242 |
# yield data_batch / IMAGE_MAX_VALUE - 0.5 |
|
|
243 |
#print("db:",data_batch.shape) |
|
|
244 |
yield data_batch, z_batch |
|
|
245 |
|
|
|
246 |
|
|
|
247 |
# In[13]: |
|
|
248 |
|
|
|
249 |
|
|
|
250 |
print(get_batches(4)) |
|
|
251 |
|
|
|
252 |
|
|
|
253 |
# In[14]: |
|
|
254 |
|
|
|
255 |
|
|
|
256 |
import tensorflow as tf |
|
|
257 |
|
|
|
258 |
def model_inputs(image_width, image_height, image_channels, z_dim): |
|
|
259 |
""" |
|
|
260 |
Create the model inputs |
|
|
261 |
""" |
|
|
262 |
inputs_real = tf.placeholder(tf.float32, shape=(None, image_width, image_height, image_channels), name='input_real') |
|
|
263 |
inputs_z = tf.placeholder(tf.float32, shape=(None,z_dim), name='input_z') |
|
|
264 |
learning_rate = tf.placeholder(tf.float32, name='learning_rate') |
|
|
265 |
|
|
|
266 |
return inputs_real, inputs_z, learning_rate |
|
|
267 |
|
|
|
268 |
|
|
|
269 |
# In[15]: |
|
|
270 |
|
|
|
271 |
|
|
|
272 |
def discriminator(images, reuse=False): |
|
|
273 |
""" |
|
|
274 |
Create the discriminator network |
|
|
275 |
""" |
|
|
276 |
alpha = 0.2 |
|
|
277 |
#print("image size:",images.shape) |
|
|
278 |
with tf.variable_scope('discriminator', reuse=reuse): |
|
|
279 |
# using 4 layer network as in DCGAN Paper |
|
|
280 |
|
|
|
281 |
# Conv 1 |
|
|
282 |
conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME') |
|
|
283 |
lrelu1 = tf.maximum(alpha * conv1, conv1) |
|
|
284 |
# print("layer1:",lrelu1.shape) |
|
|
285 |
|
|
|
286 |
# Conv 2 |
|
|
287 |
conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME') |
|
|
288 |
batch_norm2 = tf.layers.batch_normalization(conv2, training=True) |
|
|
289 |
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2) |
|
|
290 |
# print("layer2:",lrelu2.shape) |
|
|
291 |
|
|
|
292 |
# Conv 3 |
|
|
293 |
conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME') |
|
|
294 |
batch_norm3 = tf.layers.batch_normalization(conv3, training=True) |
|
|
295 |
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3) |
|
|
296 |
# print("layer3:",lrelu3.shape) |
|
|
297 |
|
|
|
298 |
# Flatten |
|
|
299 |
flat = tf.reshape(lrelu3, (-1, 1*1*256)) |
|
|
300 |
# print("layer4:",flat.shape) |
|
|
301 |
|
|
|
302 |
# Logits |
|
|
303 |
logits = tf.layers.dense(flat, 1) |
|
|
304 |
|
|
|
305 |
# Output |
|
|
306 |
out = tf.sigmoid(logits) |
|
|
307 |
|
|
|
308 |
return out, logits |
|
|
309 |
|
|
|
310 |
|
|
|
311 |
# In[16]: |
|
|
312 |
|
|
|
313 |
|
|
|
314 |
def generator(z, out_channel_dim, is_train=True): |
|
|
315 |
""" |
|
|
316 |
Create the generator network |
|
|
317 |
""" |
|
|
318 |
alpha = 0.2 |
|
|
319 |
# print("gen,z:",z.shape) |
|
|
320 |
|
|
|
321 |
with tf.variable_scope('generator', reuse=False if is_train==True else True): |
|
|
322 |
|
|
|
323 |
# using 4 layer network as in DCGAN Paper |
|
|
324 |
|
|
|
325 |
# First fully connected layer |
|
|
326 |
x_1 = tf.layers.dense(z, 2*2*512) |
|
|
327 |
#print("Gen,fully conn layer 1:",x_1.shape) |
|
|
328 |
|
|
|
329 |
# Reshape it to start the convolutional stack |
|
|
330 |
deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512)) |
|
|
331 |
batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train) |
|
|
332 |
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2) |
|
|
333 |
#print("Gen,fully conn layer 1 reshape: ",lrelu2.shape) |
|
|
334 |
|
|
|
335 |
|
|
|
336 |
# Deconv 1 |
|
|
337 |
deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID') |
|
|
338 |
batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train) |
|
|
339 |
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3) |
|
|
340 |
#print("Gen,deconv layer 1 : ",lrelu3.shape) |
|
|
341 |
|
|
|
342 |
|
|
|
343 |
# Deconv 2 |
|
|
344 |
deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME') |
|
|
345 |
batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train) |
|
|
346 |
lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4) |
|
|
347 |
#print("Gen,deconv layer 2 : ",lrelu4.shape) |
|
|
348 |
|
|
|
349 |
# Output layer |
|
|
350 |
logits = tf.layers.conv2d_transpose(lrelu4, out_channel_dim, 5, 2, padding='SAME') |
|
|
351 |
#print("Gen,output layer : ",logits.shape) |
|
|
352 |
|
|
|
353 |
out = tf.tanh(logits) |
|
|
354 |
|
|
|
355 |
return out |
|
|
356 |
|
|
|
357 |
|
|
|
358 |
# In[17]: |
|
|
359 |
|
|
|
360 |
|
|
|
361 |
def model_loss(input_real, input_z, out_channel_dim): |
|
|
362 |
""" |
|
|
363 |
Get the loss for the discriminator and generator |
|
|
364 |
""" |
|
|
365 |
|
|
|
366 |
label_smoothing = 0.9 |
|
|
367 |
|
|
|
368 |
g_model = generator(input_z, out_channel_dim) |
|
|
369 |
d_model_real, d_logits_real = discriminator(input_real) |
|
|
370 |
#print("gmodel size", g_model.shape) |
|
|
371 |
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True) |
|
|
372 |
|
|
|
373 |
|
|
|
374 |
# Change it to norm_l2 loss between generated groundtruth and actual groundtruth |
|
|
375 |
d_loss_real = tf.reduce_mean( |
|
|
376 |
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, |
|
|
377 |
labels=tf.ones_like(d_model_real) * label_smoothing)) |
|
|
378 |
d_loss_fake = tf.reduce_mean( |
|
|
379 |
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, |
|
|
380 |
labels=tf.zeros_like(d_model_fake))) |
|
|
381 |
|
|
|
382 |
d_loss = d_loss_real + d_loss_fake |
|
|
383 |
|
|
|
384 |
g_loss = tf.reduce_mean( |
|
|
385 |
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, |
|
|
386 |
labels=tf.ones_like(d_model_fake) * label_smoothing)) |
|
|
387 |
|
|
|
388 |
|
|
|
389 |
return d_loss, g_loss |
|
|
390 |
|
|
|
391 |
|
|
|
392 |
# In[18]: |
|
|
393 |
|
|
|
394 |
|
|
|
395 |
def model_opt(d_loss, g_loss, learning_rate, beta1): |
|
|
396 |
""" |
|
|
397 |
Get optimization operations |
|
|
398 |
""" |
|
|
399 |
t_vars = tf.trainable_variables() |
|
|
400 |
d_vars = [var for var in t_vars if var.name.startswith('discriminator')] |
|
|
401 |
g_vars = [var for var in t_vars if var.name.startswith('generator')] |
|
|
402 |
|
|
|
403 |
# Optimize |
|
|
404 |
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): |
|
|
405 |
d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars) |
|
|
406 |
g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars) |
|
|
407 |
|
|
|
408 |
return d_train_opt, g_train_opt |
|
|
409 |
|
|
|
410 |
|
|
|
411 |
# In[19]: |
|
|
412 |
|
|
|
413 |
|
|
|
414 |
def show_generator_output(sess, n_images, input_z, out_channel_dim,counter): |
|
|
415 |
""" |
|
|
416 |
Show example output for the generator |
|
|
417 |
""" |
|
|
418 |
# z_dim = input_z.get_shape().as_list()[-1] |
|
|
419 |
# example_z = np.random.uniform(-1, 1, size=[n_images, z_dim]) |
|
|
420 |
example_z = np.reshape(flair_data[420,:,:],(1,IMAGE_WIDTH*IMAGE_HEIGHT)) |
|
|
421 |
samples = sess.run( |
|
|
422 |
generator(input_z, out_channel_dim, False), |
|
|
423 |
feed_dict={input_z: example_z}) |
|
|
424 |
|
|
|
425 |
#print("SAmples shape: ", samples.shape) |
|
|
426 |
pyplot.imshow(samples[0,:,:,0]) |
|
|
427 |
path = "out"+str(counter)+".png" |
|
|
428 |
pyplot.savefig(path) |
|
|
429 |
pyplot.show() |
|
|
430 |
|
|
|
431 |
|
|
|
432 |
# In[20]: |
|
|
433 |
|
|
|
434 |
|
|
|
435 |
def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape): |
|
|
436 |
""" |
|
|
437 |
Train the GAN |
|
|
438 |
""" |
|
|
439 |
input_real, input_z, _ = model_inputs(data_shape[1], data_shape[2], data_shape[3], z_dim) |
|
|
440 |
d_loss, g_loss = model_loss(input_real, input_z, data_shape[3]) |
|
|
441 |
d_opt, g_opt = model_opt(d_loss, g_loss, learning_rate, beta1) |
|
|
442 |
|
|
|
443 |
steps = 0 |
|
|
444 |
|
|
|
445 |
with tf.Session() as sess: |
|
|
446 |
sess.run(tf.global_variables_initializer()) |
|
|
447 |
for epoch_i in range(epoch_count): |
|
|
448 |
for batch_images,batch_z in get_batches(batch_size): |
|
|
449 |
|
|
|
450 |
# values range from -0.5 to 0.5, therefore scale to range -1, 1 |
|
|
451 |
# batch_images = batch_images * 2 |
|
|
452 |
steps += 1 |
|
|
453 |
batch_z = np.reshape(batch_z,(batch_size, IMAGE_WIDTH*IMAGE_HEIGHT)) |
|
|
454 |
# batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim) |
|
|
455 |
#print("Batch:",batch_images.shape) |
|
|
456 |
#print("Batch Z:",batch_z.shape) |
|
|
457 |
|
|
|
458 |
_ = sess.run(d_opt, feed_dict={input_real: batch_images, input_z: batch_z}) |
|
|
459 |
_ = sess.run(g_opt, feed_dict={input_real: batch_images, input_z: batch_z}) |
|
|
460 |
counter = 0 |
|
|
461 |
if steps % 400 == 0: |
|
|
462 |
counter = counter+1 |
|
|
463 |
# At the end of every 10 epochs, get the losses and print them out |
|
|
464 |
train_loss_d = d_loss.eval({input_z: batch_z, input_real: batch_images}) |
|
|
465 |
train_loss_g = g_loss.eval({input_z: batch_z}) |
|
|
466 |
|
|
|
467 |
print("Epoch {}/{}...".format(epoch_i+1, epochs), |
|
|
468 |
"Discriminator Loss: {:.4f}...".format(train_loss_d), |
|
|
469 |
"Generator Loss: {:.4f}".format(train_loss_g)) |
|
|
470 |
|
|
|
471 |
_ = show_generator_output(sess, 1, input_z, data_shape[3],(steps/40)) |
|
|
472 |
|
|
|
473 |
|
|
|
474 |
# In[21]: |
|
|
475 |
|
|
|
476 |
|
|
|
477 |
#### import tensorflow as tf |
|
|
478 |
batch_size = 5 |
|
|
479 |
z_dim = 784 |
|
|
480 |
learning_rate = 0.0002 |
|
|
481 |
beta1 = 0.5 |
|
|
482 |
epochs = 100 |
|
|
483 |
|
|
|
484 |
with tf.Graph().as_default(): |
|
|
485 |
train(epochs, batch_size, z_dim, learning_rate, beta1, get_batches, shape) |
|
|
486 |
|