Switch to side-by-side view

--- a
+++ b/Medical-Image-Segmentation_DCGAN.py
@@ -0,0 +1,486 @@
+
+# coding: utf-8
+
+# In[1]:
+
+
+import os
+from medpy.io import load
+import numpy as np
+import cv2 as cv
+from PIL import Image
+
+PATH = os.path.abspath("E:/UB CSE/Spring 2018/700/Project/BRATS2013/BRATS_Training/BRATS-2/Image_Data")
+
+# pad image to standardize size, then crop down as required (by memory constraints)
def pad_image(img, desired_shape=(256, 256), crop=((50, 200), (50, 200)), out_size=(28, 28)):
    """Zero-pad a 2-D layer up to ``desired_shape``, crop a fixed window, and
    resize the crop for the network input.

    Parameters
    ----------
    img : np.ndarray
        2-D image layer.
    desired_shape : tuple of int
        Minimum (rows, cols) after symmetric zero-padding; a dimension that
        is already at least this large is left unpadded.
    crop : pair of (start, stop) tuples
        Row and column slice applied after padding. Defaults preserve the
        original hard-coded 150x150 centre window.
    out_size : tuple of int
        Final size, passed to cv.resize as dsize (i.e. (width, height)).

    Returns
    -------
    np.ndarray
        The padded, cropped, and resized layer.
    """
    pad_top = pad_bot = pad_left = pad_right = 0
    if desired_shape[0] > img.shape[0]:
        pad_top = (desired_shape[0] - img.shape[0]) // 2
        pad_bot = desired_shape[0] - img.shape[0] - pad_top
    if desired_shape[1] > img.shape[1]:
        pad_left = (desired_shape[1] - img.shape[1]) // 2
        pad_right = desired_shape[1] - img.shape[1] - pad_left
    img = np.pad(img, ((pad_top, pad_bot), (pad_left, pad_right)), 'constant')

    # Crop away mostly-empty borders (memory constraint), then downsample.
    (row_start, row_stop), (col_start, col_stop) = crop
    img = img[row_start:row_stop, col_start:col_stop]
    # INTER_CUBIC keeps the downsampled layer smooth.
    img = cv.resize(img, dsize=out_size, interpolation=cv.INTER_CUBIC)

    return img
+
+
def normalize(img):
    """Min-max normalize one image layer to [0.0, 1.0], then pad/crop/resize
    it via pad_image and round intensities to 2 decimal places.

    Returns the processed float array.
    """
    nimg = cv.normalize(img.astype('float'), None, alpha=0.0, beta=1.0, norm_type=cv.NORM_MINMAX)
    nimg = pad_image(nimg, desired_shape=(256, 256))
    # BUG FIX: ndarray.round() returns a NEW array; the original discarded
    # the result, so the rounding never took effect.
    return nimg.round(decimals=2)
+
+
def load_single_image(path):
    """Walk ``path`` and load the first '.mha' volume found via load_itk.

    Returns the image array, or None when no '.mha' file exists under
    ``path`` (same fall-through as the original).

    BUG FIX: a file discovered inside a subdirectory is joined against the
    directory it was found in (``root``), not the top-level ``path`` — the
    original built a non-existent path for nested files.
    """
    for root, _, files in os.walk(path):
        for fname in files:
            if fname.endswith(".mha"):
                return load_itk(os.path.join(root, fname))
    return None
+
+
def create_1_chan_data(flair, ot):
    """Split two aligned 3-D volumes into stacks of per-layer 2-D samples.

    Each axial layer of the ground-truth volume ``ot`` is padded/cropped via
    pad_image; the matching FLAIR layer is min-max normalized via normalize.
    Returns (ot_stack, flair_stack), both stacked along a new leading axis.
    """
    num_layers = ot.shape[2]
    ot_layers = [pad_image(ot[:, :, idx], desired_shape=(256, 256))
                 for idx in range(num_layers)]
    flair_layers = [normalize(flair[:, :, idx]) for idx in range(num_layers)]

    return np.stack(ot_layers, axis=0), np.stack(flair_layers, axis=0)
+
+# BRaTS dataset contains 4 channels of input data and one channel of groundtruth for a 3D brain scan image. 
# BRaTS dataset contains 4 channels of input data and one channel of groundtruth for a 3D brain scan image.
def load_dataset(path):
    """Load the BRATS training set rooted at ``path``.

    Walks the 'HG' (high-grade) and 'LG' (low-grade) patient directories,
    loads each patient's FLAIR volume and its ground-truth (OT) volume,
    slices them into per-layer 2-D samples, and stacks everything.

    Returns
    -------
    (train_flair, train_ot) : tuple of np.ndarray
        Same-shaped arrays of (total_layers, H, W).

    Raises
    ------
    FileNotFoundError
        If a patient directory contains no ground-truth volume.
        (BUG FIX: the original raised a confusing NameError here because
        ``LG_ot`` was never bound when none of the OT variants existed.)
    """
    train_flair = []
    train_ot = []

    def _append_case(case_path, ot_names):
        """Load one patient directory: FLAIR plus the last OT variant that
        exists on disk (matches the original's overwrite-in-order logic)."""
        flair = load_single_image(os.path.join(case_path, 'VSD.Brain.XX.O.MR_Flair'))
        ot = None
        for ot_name in ot_names:
            ot_dir = os.path.join(case_path, ot_name)
            if os.path.exists(ot_dir):
                ot = load_single_image(ot_dir)
        if ot is None:
            raise FileNotFoundError("no ground-truth (OT) volume under %s" % case_path)
        assert ot.shape == flair.shape
        samples = create_1_chan_data(flair, ot)
        train_ot.append(samples[0])
        train_flair.append(samples[1])

    for grade in os.listdir(path):
        if grade == 'HG':
            hg_path = os.path.join(path, 'HG')
            for case in os.listdir(hg_path):
                if case != '.DS_Store':
                    _append_case(os.path.join(hg_path, case),
                                 ['VSD.Brain_3more.XX.XX.OT'])

        if grade == 'LG':
            lg_path = os.path.join(path, 'LG')
            for case in os.listdir(lg_path):
                if case != '.DS_Store':
                    _append_case(os.path.join(lg_path, case),
                                 ['VSD.Brain_1more.XX.XX.OT',
                                  'VSD.Brain_2more.XX.XX.OT',
                                  'VSD.Brain_3more.XX.XX.OT'])

    # Stacking all individual layers into single (N, H, W) arrays.
    train_ot = np.vstack(train_ot)
    train_flair = np.vstack(train_flair)
    assert train_ot.shape == train_flair.shape
    return train_flair, train_ot
+
+
+# In[2]:
+
+#SimpleITK is used for reading the brain scan images
+import SimpleITK as sitk
+import numpy as np
+import os
+import glob
+from medpy.io import load
+'''
+This function reads a '.mhd' file using SimpleITK and returns the image array, origin and spacing of the image.
+'''
+
def load_itk(filename):
    """Read a scan file with SimpleITK and return it as a numpy array.

    GetArrayFromImage yields axes in (z, y, x) order. The world-coordinate
    origin and per-axis spacing are computed (reversed into that same order)
    but currently not returned; re-enable the commented return to get them.
    """
    itk_image = sitk.ReadImage(filename)

    # Axes come back as (z, y, x) from SimpleITK's array conversion.
    volume = sitk.GetArrayFromImage(itk_image)

    # Origin / spacing reversed to match the (z, y, x) array order; useful
    # for world<->voxel coordinate conversion if ever needed.
    origin = np.array(list(reversed(itk_image.GetOrigin())))
    spacing = np.array(list(reversed(itk_image.GetSpacing())))

    # return volume, origin, spacing
    return volume
+
+
+# In[3]:
+
+
# Load the full BRATS training set: per-layer FLAIR inputs and OT ground truth.
flair_data, ot_data =load_dataset(PATH)


# In[4]:


print(flair_data.shape)


# In[5]:


import matplotlib.pyplot as plt
# fig1 = plt.figure()
# Visual sanity check: render one ground-truth segmentation layer.
plt.imshow(ot_data[420,:,:])
plt.savefig('sample.png')
plt.show()


# In[6]:


# Distinct label values present in the sampled ground-truth layer.
print(np.unique(ot_data[420,:,:]))


# In[7]:


# imginput = x[0]
# imgoutput = x[1]


# In[8]:


print(flair_data.shape)


# In[9]:


print(ot_data.shape)


# In[10]:


# Maximum ground-truth label value across the whole dataset.
np.amax(ot_data)


# # Experiment

# In[11]:


import os
from glob import glob
from matplotlib import pyplot
from PIL import Image
import numpy as np


# Image configuration
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
data_files = PATH
# shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT,1
# NHWC-style shape of the training data (channel axis is added per batch).
shape = flair_data.shape[0],flair_data.shape[1],flair_data.shape[2],1
print(shape)
+
+
+# In[12]:
+
+
+
def get_batches(batch_size):
    """Yield (ground-truth, flair) batch pairs from the module-level arrays.

    The ground-truth batch gains a trailing channel axis so it matches the
    discriminator's NHWC input; the FLAIR batch is yielded unchanged (the
    caller flattens it into z vectors). Any final partial batch is dropped.
    """
    start = 0
    while start + batch_size <= shape[0]:
        stop = start + batch_size
        # Add the channel axis: (batch, H, W) -> (batch, H, W, 1).
        gt_batch = ot_data[start:stop][..., np.newaxis]
        flair_batch = flair_data[start:stop]
        start = stop
        yield gt_batch, flair_batch
+
+
# In[13]:


# Note: get_batches is a generator function, so this prints a generator
# object rather than any batch data.
print(get_batches(4))


# In[14]:


import tensorflow as tf
+
def model_inputs(image_width, image_height, image_channels, z_dim):
    """Build the graph's input placeholders.

    Returns a triple: the real NHWC image batch, the flattened z /
    conditioning vector, and a scalar learning rate.
    """
    real_images = tf.placeholder(
        tf.float32, (None, image_width, image_height, image_channels), name='input_real')
    z_vectors = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    lr = tf.placeholder(tf.float32, name='learning_rate')

    return real_images, z_vectors, lr
+
+
+# In[15]:
+
+
def discriminator(images, reuse=False):
    """DCGAN-style discriminator: three conv layers, then a dense logit.

    Args:
        images: NHWC batch of input images.
        reuse: share variables with a previous call when True.

    Returns:
        (sigmoid output, raw logits).
    """
    alpha = 0.2

    def leaky(t):
        # LeakyReLU written as an elementwise max.
        return tf.maximum(alpha * t, t)

    with tf.variable_scope('discriminator', reuse=reuse):
        # Conv 1: stride 2 halves the spatial dimensions.
        h1 = leaky(tf.layers.conv2d(images, 64, 5, 2, 'SAME'))

        # Conv 2: stride 2 again, batch-normalized.
        h2 = tf.layers.conv2d(h1, 128, 5, 2, 'SAME')
        h2 = leaky(tf.layers.batch_normalization(h2, training=True))

        # Conv 3: stride 1, batch-normalized.
        h3 = tf.layers.conv2d(h2, 256, 5, 1, 'SAME')
        h3 = leaky(tf.layers.batch_normalization(h3, training=True))

        # NOTE(review): with 28x28 inputs h3 is 7x7x256, so this reshape
        # folds the 49 spatial positions into the batch axis — each position
        # receives its own logit. Confirm this is intended.
        flat = tf.reshape(h3, (-1, 1 * 1 * 256))

        logits = tf.layers.dense(flat, 1)

        return tf.sigmoid(logits), logits
+
+
+# In[16]:
+
+
def generator(z, out_channel_dim, is_train=True):
    """DCGAN generator: project z, then upsample via transposed convolutions.

    Args:
        z: (batch, z_dim) latent / conditioning vectors.
        out_channel_dim: number of channels in the generated image.
        is_train: selects batch-norm training mode; variables are reused
            whenever not training (the sampling path).

    Returns:
        tanh-activated images with values in [-1, 1].
    """
    alpha = 0.2

    def leaky(t):
        return tf.maximum(alpha * t, t)

    with tf.variable_scope('generator', reuse=not is_train):
        # Project and reshape: z -> 2x2x512 feature map.
        seed = tf.layers.dense(z, 2 * 2 * 512)
        seed = tf.reshape(seed, (-1, 2, 2, 512))
        seed = leaky(tf.layers.batch_normalization(seed, training=is_train))

        # Deconv 1: 2x2 -> 7x7 (VALID padding, 5x5 kernel, stride 2).
        up1 = tf.layers.conv2d_transpose(seed, 256, 5, 2, padding='VALID')
        up1 = leaky(tf.layers.batch_normalization(up1, training=is_train))

        # Deconv 2: stride 2 doubles the spatial dimensions (7x7 -> 14x14).
        up2 = tf.layers.conv2d_transpose(up1, 128, 5, 2, padding='SAME')
        up2 = leaky(tf.layers.batch_normalization(up2, training=is_train))

        # Output layer: final doubling (14x14 -> 28x28), squashed to [-1, 1].
        logits = tf.layers.conv2d_transpose(up2, out_channel_dim, 5, 2, padding='SAME')

        return tf.tanh(logits)
+
+
+# In[17]:
+
+
def model_loss(input_real, input_z, out_channel_dim):
    """Return (discriminator loss, generator loss) for the GAN.

    One-sided label smoothing (0.9) is applied to the "real" targets.
    """
    label_smoothing = 0.9

    g_model = generator(input_z, out_channel_dim)
    d_model_real, d_logits_real = discriminator(input_real)
    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

    def _xent(logits, labels):
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    # Discriminator: push real images toward (smoothed) 1, fakes toward 0.
    d_loss_real = _xent(d_logits_real, tf.ones_like(d_model_real) * label_smoothing)
    d_loss_fake = _xent(d_logits_fake, tf.zeros_like(d_model_fake))
    d_loss = d_loss_real + d_loss_fake

    # Generator: fool the discriminator into labelling fakes as real.
    # TODO(review): the original note suggests replacing this with an L2 loss
    # between generated and actual ground truth for the segmentation setting.
    g_loss = _xent(d_logits_fake, tf.ones_like(d_model_fake) * label_smoothing)

    return d_loss, g_loss
+
+
+# In[18]:
+
+
def model_opt(d_loss, g_loss, learning_rate, beta1):
    """Create Adam training ops for the discriminator and generator.

    Variables are partitioned by variable-scope name prefix; the
    control dependency on UPDATE_OPS ensures batch-norm statistics are
    refreshed before each optimizer step.
    """
    all_vars = tf.trainable_variables()
    d_vars = [v for v in all_vars if v.name.startswith('discriminator')]
    g_vars = [v for v in all_vars if v.name.startswith('generator')]

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
+
+
+# In[19]:
+
+
def show_generator_output(sess, n_images, input_z, out_channel_dim, counter):
    """Run the generator on a fixed FLAIR slice and save the sample image.

    Note: n_images is currently unused — the conditioning input is always
    the single flattened flair_data[420] layer. The first generated channel
    is written to out<counter>.png and also shown.
    """
    example_z = np.reshape(flair_data[420, :, :], (1, IMAGE_WIDTH * IMAGE_HEIGHT))
    samples = sess.run(
        generator(input_z, out_channel_dim, False),
        feed_dict={input_z: example_z})

    pyplot.imshow(samples[0, :, :, 0])
    out_path = "out" + str(counter) + ".png"
    pyplot.savefig(out_path)
    pyplot.show()
+
+
+# In[20]:
+
+
def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape):
    """Train the GAN.

    Args:
        epoch_count: number of passes over the dataset.
        batch_size: samples per batch.
        z_dim: length of the flattened conditioning vector (width * height).
        learning_rate, beta1: Adam hyper-parameters.
        get_batches: generator function yielding (images, z) batch pairs.
        data_shape: (n, height, width, channels) of the training images.
    """
    input_real, input_z, _ = model_inputs(data_shape[1], data_shape[2], data_shape[3], z_dim)
    d_loss, g_loss = model_loss(input_real, input_z, data_shape[3])
    d_opt, g_opt = model_opt(d_loss, g_loss, learning_rate, beta1)

    steps = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):
            for batch_images, batch_z in get_batches(batch_size):
                steps += 1
                # Flatten the conditioning FLAIR images into z vectors.
                batch_z = np.reshape(batch_z, (batch_size, IMAGE_WIDTH * IMAGE_HEIGHT))

                _ = sess.run(d_opt, feed_dict={input_real: batch_images, input_z: batch_z})
                _ = sess.run(g_opt, feed_dict={input_real: batch_images, input_z: batch_z})

                # Every 400 steps: report losses and save a generator sample.
                if steps % 400 == 0:
                    train_loss_d = d_loss.eval({input_z: batch_z, input_real: batch_images})
                    train_loss_g = g_loss.eval({input_z: batch_z})

                    # BUG FIX: report the epoch_count parameter instead of
                    # relying on the module-level global `epochs`.
                    print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
                          "Generator Loss: {:.4f}".format(train_loss_g))

                    # BUG FIX: the original reset `counter` to 0 every batch
                    # (dead code) and passed the float steps/40; use a clean
                    # integer snapshot index so each file gets a distinct,
                    # well-formed name.
                    _ = show_generator_output(sess, 1, input_z, data_shape[3], steps // 400)
+
+
+# In[21]:
+
+
#### import tensorflow as tf
# Training hyper-parameters.
batch_size = 5
z_dim = 784  # 28*28: one flattened FLAIR layer is the conditioning input
learning_rate = 0.0002
beta1 = 0.5
epochs = 100

# Build everything in a fresh graph and run training end-to-end.
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, get_batches, shape)
+