Switch to unified view

a b/Medical-Image-Segmentation_DCGAN.py
1
2
# coding: utf-8
3
4
# In[1]:
5
6
7
import os
8
from medpy.io import load
9
import numpy as np
10
import cv2 as cv
11
from PIL import Image
12
13
PATH = os.path.abspath("E:/UB CSE/Spring 2018/700/Project/BRATS2013/BRATS_Training/BRATS-2/Image_Data")
14
15
# pad image to standardize size, then crop down as required (by memory constraints)
16
def pad_image(img, desired_shape=(256, 256)):
17
    pad_top = 0
18
    pad_bot = 0
19
    pad_left = 0
20
    pad_right = 0
21
    if desired_shape[0] > img.shape[0]:
22
        pad_top = int((desired_shape[0] - img.shape[0]) / 2)
23
        pad_bot = desired_shape[0] - img.shape[0] - pad_top
24
    if desired_shape[1] > img.shape[1]:
25
        pad_left = int((desired_shape[1] - img.shape[1]) / 2)
26
        pad_right = desired_shape[1] - img.shape[1] - pad_left
27
    img = np.pad(img, ((pad_top, pad_bot), (pad_left, pad_right)), 'constant')
28
    
29
    img = img[50:200,50:200]
30
    img = cv.resize(img, dsize=(28,28), interpolation=cv.INTER_CUBIC)
31
    
32
    return img
33
34
35
def normalize(img):
36
    nimg = None
37
    nimg = cv.normalize(img.astype('float'), nimg, alpha=0.0, beta=1.0, norm_type=cv.NORM_MINMAX)
38
    nimg = pad_image(nimg, desired_shape=(256, 256))
39
    nimg.round(decimals=2)
40
    return nimg
41
42
43
def load_single_image(path):
44
    for dir, subdir, files in os.walk(path):
45
        for file in files:
46
            if file.endswith(".mha"):
47
                img = load_itk(os.path.join(path, file))
48
                return img
49
50
51
def create_1_chan_data(flair, ot):
52
    ot_layers = []
53
    flair_layers = []
54
#     print("OT shape",ot.shape[2])
55
    for layer in range(ot.shape[2]):
56
        ot_layers.append(pad_image(ot[:, :, layer], desired_shape=(256, 256)))
57
#         print("Flair intensities: ", np.unique(flair[:, :, layer]))
58
        normalizedImage = normalize(flair[:, :, layer])
59
#         print("Normalized Image intensities: ", np.unique(normalizedImage))
60
        flair_layers.append(normalizedImage)
61
62
    return np.stack(ot_layers, axis=0), np.stack(flair_layers, axis=0)
63
64
# BRaTS dataset contains 4 channels of input data and one channel of groundtruth for a 3D brain scan image. 
65
def load_dataset(path):
66
    
67
    train_flair = []
68
    train_ot = []
69
70
    for dir in os.listdir(path):
71
        if dir == 'HG':
72
            HG_path = os.path.join(path, 'HG')
73
            for dir2 in os.listdir(HG_path):
74
                if dir2 != '.DS_Store':
75
                    HG_flair = load_single_image(os.path.join(HG_path, dir2, 'VSD.Brain.XX.O.MR_Flair'))
76
                    HG_ot = load_single_image(os.path.join(HG_path, dir2, 'VSD.Brain_3more.XX.XX.OT'))
77
                    assert (HG_ot.shape == HG_flair.shape )
78
                    HG_samples = create_1_chan_data(HG_flair, HG_ot)
79
                    train_ot.append(HG_samples[0])
80
                    train_flair.append(HG_samples[1])
81
82
        if dir == 'LG':
83
            brain_1 = brain_2 = brain_3 = False
84
            LG_path = os.path.join(path, 'LG')
85
            for dir3 in os.listdir(LG_path):
86
                if dir3 != '.DS_Store':
87
                    LG_flair = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain.XX.O.MR_Flair'))
88
                    brain_1 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_1more.XX.XX.OT'))
89
                    brain_2 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_2more.XX.XX.OT'))
90
                    brain_3 = os.path.exists(os.path.join(LG_path, dir3, 'VSD.Brain_3more.XX.XX.OT'))
91
                    if brain_1:
92
                        LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_1more.XX.XX.OT'))
93
                    if brain_2:
94
                        LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_2more.XX.XX.OT'))
95
                    if brain_3:
96
                        LG_ot = load_single_image(os.path.join(LG_path, dir3, 'VSD.Brain_3more.XX.XX.OT'))
97
98
                    assert (LG_ot.shape == LG_flair.shape)
99
                    LG_samples = create_1_chan_data(LG_flair, LG_ot)
100
                    train_ot.append(LG_samples[0])
101
                    train_flair.append(LG_samples[1])
102
    # Stacking all individual layers
103
    train_ot = np.vstack(train_ot)
104
    train_flair = np.vstack(train_flair)
105
    assert (train_ot.shape == train_flair.shape)
106
    return train_flair,train_ot
107
108
109
# In[2]:
110
111
#SimpleITK is used for reading the brain scan images
112
import SimpleITK as sitk
113
import numpy as np
114
import os
115
import glob
116
from medpy.io import load
117
'''
118
This funciton reads a '.mhd' file using SimpleITK and return the image array, origin and spacing of the image.
119
'''
120
121
def load_itk(filename):
122
    # Reads the image using SimpleITK
123
    itkimage = sitk.ReadImage(filename)
124
125
    # Convert the image to a  numpy array first and then shuffle the dimensions to get axis in the order z,y,x
126
    ct_scan = sitk.GetArrayFromImage(itkimage)
127
128
    # Read the origin of the ct_scan, will be used to convert the coordinates from world to voxel and vice versa.
129
    origin = np.array(list(reversed(itkimage.GetOrigin())))
130
131
    # Read the spacing along each dimension
132
    spacing = np.array(list(reversed(itkimage.GetSpacing())))
133
134
#     return ct_scan, origin, spacing
135
    return ct_scan
136
137
138
# In[3]:
139
140
141
flair_data, ot_data =load_dataset(PATH)
142
143
144
# In[4]:
145
146
147
print(flair_data.shape)
148
149
150
# In[5]:
151
152
153
import matplotlib.pyplot as plt
154
# fig1 = plt.figure()
155
plt.imshow(ot_data[420,:,:])
156
plt.savefig('sample.png')
157
plt.show()
158
159
160
# In[6]:
161
162
163
print(np.unique(ot_data[420,:,:]))
164
165
166
# In[7]:
167
168
169
# imginput = x[0]
170
# imgoutput = x[1]
171
172
173
# In[8]:
174
175
176
print(flair_data.shape)
177
178
179
# In[9]:
180
181
182
print(ot_data.shape)
183
184
185
# In[10]:
186
187
188
np.amax(ot_data)
189
190
191
# # Experiment
192
193
# In[11]:
194
195
196
import os
197
from glob import glob
198
from matplotlib import pyplot
199
from PIL import Image
200
import numpy as np
201
202
203
# Image configuration
204
IMAGE_HEIGHT = 28
205
IMAGE_WIDTH = 28
206
data_files = PATH
207
# shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT,1
208
shape = flair_data.shape[0],flair_data.shape[1],flair_data.shape[2],1
209
print(shape)
210
211
212
# In[12]:
213
214
215
216
def get_batches(batch_size):
217
    """
218
    Generate batches
219
    """
220
#     IMAGE_MAX_VALUE = 255
221
222
223
    current_index = 0
224
    while current_index + batch_size <= shape[0]:
225
226
        data_batch = (ot_data[current_index:current_index + batch_size])
227
        z_batch = (flair_data[current_index:current_index + batch_size])
228
        #print(type(data_batch))
229
        #print(data_batch.shape)
230
        data_batch = data_batch[...,np.newaxis]
231
        #print(data_batch.shape)
232
        
233
234
#         np.vstack((data_batch, x[1,current_index:current_index + batch_size]))
235
        
236
        
237
238
        current_index += batch_size
239
        
240
#         return data_batch / IMAGE_MAX_VALUE - 0.5
241
        
242
#         yield data_batch / IMAGE_MAX_VALUE - 0.5
243
        #print("db:",data_batch.shape)
244
        yield data_batch, z_batch
245
246
247
# In[13]:
248
249
250
print(get_batches(4))
251
252
253
# In[14]:
254
255
256
import tensorflow as tf
257
258
def model_inputs(image_width, image_height, image_channels, z_dim):
259
    """
260
    Create the model inputs
261
    """
262
    inputs_real = tf.placeholder(tf.float32, shape=(None, image_width, image_height, image_channels), name='input_real') 
263
    inputs_z = tf.placeholder(tf.float32, shape=(None,z_dim), name='input_z')
264
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
265
    
266
    return inputs_real, inputs_z, learning_rate
267
268
269
# In[15]:
270
271
272
def discriminator(images, reuse=False):
273
    """
274
    Create the discriminator network
275
    """
276
    alpha = 0.2
277
    #print("image size:",images.shape)
278
    with tf.variable_scope('discriminator', reuse=reuse):
279
        # using 4 layer network as in DCGAN Paper
280
        
281
        # Conv 1
282
        conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME')
283
        lrelu1 = tf.maximum(alpha * conv1, conv1)
284
#        print("layer1:",lrelu1.shape)
285
        
286
        # Conv 2
287
        conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME')
288
        batch_norm2 = tf.layers.batch_normalization(conv2, training=True)
289
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
290
#        print("layer2:",lrelu2.shape)
291
292
        # Conv 3
293
        conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME')
294
        batch_norm3 = tf.layers.batch_normalization(conv3, training=True)
295
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
296
#        print("layer3:",lrelu3.shape)
297
298
        # Flatten
299
        flat = tf.reshape(lrelu3, (-1, 1*1*256))
300
#        print("layer4:",flat.shape)
301
        
302
        # Logits
303
        logits = tf.layers.dense(flat, 1)
304
        
305
        # Output
306
        out = tf.sigmoid(logits)
307
        
308
        return out, logits
309
310
311
# In[16]:
312
313
314
def generator(z, out_channel_dim, is_train=True):
315
    """
316
    Create the generator network
317
    """
318
    alpha = 0.2
319
#    print("gen,z:",z.shape)
320
    
321
    with tf.variable_scope('generator', reuse=False if is_train==True else True):
322
               
323
        # using 4 layer network as in DCGAN Paper
324
325
        # First fully connected layer
326
        x_1 = tf.layers.dense(z, 2*2*512)
327
        #print("Gen,fully conn layer 1:",x_1.shape)
328
        
329
        # Reshape it to start the convolutional stack
330
        deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512))
331
        batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train)
332
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
333
        #print("Gen,fully conn layer 1 reshape:  ",lrelu2.shape)
334
335
        
336
        # Deconv 1
337
        deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID')
338
        batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train)
339
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
340
        #print("Gen,deconv layer 1 : ",lrelu3.shape)
341
342
        
343
        # Deconv 2
344
        deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME')
345
        batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train)
346
        lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4)
347
        #print("Gen,deconv layer 2 : ",lrelu4.shape)
348
349
        # Output layer
350
        logits = tf.layers.conv2d_transpose(lrelu4, out_channel_dim, 5, 2, padding='SAME')
351
        #print("Gen,output layer : ",logits.shape)
352
353
        out = tf.tanh(logits)
354
        
355
        return out
356
357
358
# In[17]:
359
360
361
def model_loss(input_real, input_z, out_channel_dim):
362
    """
363
    Get the loss for the discriminator and generator
364
    """
365
    
366
    label_smoothing = 0.9
367
    
368
    g_model = generator(input_z, out_channel_dim)
369
    d_model_real, d_logits_real = discriminator(input_real)
370
    #print("gmodel size", g_model.shape)
371
    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)
372
    
373
    
374
#     Change it to norm_l2 loss between generated groundtruth and actual groundtruth
375
    d_loss_real = tf.reduce_mean(
376
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
377
                                                labels=tf.ones_like(d_model_real) * label_smoothing))
378
    d_loss_fake = tf.reduce_mean(
379
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
380
                                                labels=tf.zeros_like(d_model_fake)))
381
    
382
    d_loss = d_loss_real + d_loss_fake
383
                                                  
384
    g_loss = tf.reduce_mean(
385
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
386
                                                labels=tf.ones_like(d_model_fake) * label_smoothing))
387
    
388
    
389
    return d_loss, g_loss
390
391
392
# In[18]:
393
394
395
def model_opt(d_loss, g_loss, learning_rate, beta1):
396
    """
397
    Get optimization operations
398
    """
399
    t_vars = tf.trainable_variables()
400
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
401
    g_vars = [var for var in t_vars if var.name.startswith('generator')]
402
403
    # Optimize
404
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 
405
        d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
406
        g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
407
408
    return d_train_opt, g_train_opt
409
410
411
# In[19]:
412
413
414
def show_generator_output(sess, n_images, input_z, out_channel_dim,counter):
415
    """
416
    Show example output for the generator
417
    """
418
#     z_dim = input_z.get_shape().as_list()[-1]
419
#     example_z = np.random.uniform(-1, 1, size=[n_images, z_dim])
420
    example_z = np.reshape(flair_data[420,:,:],(1,IMAGE_WIDTH*IMAGE_HEIGHT))
421
    samples = sess.run(
422
        generator(input_z, out_channel_dim, False),
423
        feed_dict={input_z: example_z})
424
425
    #print("SAmples shape: ", samples.shape)
426
    pyplot.imshow(samples[0,:,:,0])
427
    path = "out"+str(counter)+".png"
428
    pyplot.savefig(path)
429
    pyplot.show()
430
431
432
# In[20]:
433
434
435
def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape):
436
    """
437
    Train the GAN
438
    """
439
    input_real, input_z, _ = model_inputs(data_shape[1], data_shape[2], data_shape[3], z_dim)
440
    d_loss, g_loss = model_loss(input_real, input_z, data_shape[3])
441
    d_opt, g_opt = model_opt(d_loss, g_loss, learning_rate, beta1)
442
    
443
    steps = 0
444
    
445
    with tf.Session() as sess:
446
        sess.run(tf.global_variables_initializer())
447
        for epoch_i in range(epoch_count):
448
            for batch_images,batch_z in get_batches(batch_size):
449
                
450
                # values range from -0.5 to 0.5, therefore scale to range -1, 1
451
#                 batch_images = batch_images * 2
452
                steps += 1
453
                batch_z = np.reshape(batch_z,(batch_size, IMAGE_WIDTH*IMAGE_HEIGHT))
454
#                 batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim)
455
                #print("Batch:",batch_images.shape)
456
                #print("Batch Z:",batch_z.shape)
457
458
                _ = sess.run(d_opt, feed_dict={input_real: batch_images, input_z: batch_z})
459
                _ = sess.run(g_opt, feed_dict={input_real: batch_images, input_z: batch_z})
460
                counter = 0
461
                if steps % 400 == 0:
462
                    counter = counter+1
463
                    # At the end of every 10 epochs, get the losses and print them out
464
                    train_loss_d = d_loss.eval({input_z: batch_z, input_real: batch_images})
465
                    train_loss_g = g_loss.eval({input_z: batch_z})
466
467
                    print("Epoch {}/{}...".format(epoch_i+1, epochs),
468
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
469
                          "Generator Loss: {:.4f}".format(train_loss_g))
470
                    
471
                    _ = show_generator_output(sess, 1, input_z, data_shape[3],(steps/40))
472
473
474
# In[21]:
475
476
477
#### import tensorflow as tf
478
batch_size = 5
479
z_dim = 784
480
learning_rate = 0.0002
481
beta1 = 0.5
482
epochs = 100
483
484
with tf.Graph().as_default():
485
    train(epochs, batch_size, z_dim, learning_rate, beta1, get_batches, shape)
486