--- a
+++ b/BRATS2015.py
@@ -0,0 +1,253 @@
+#%%
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import skimage.io as io
+import skimage.transform as trans
+import random as r
+from keras.models import Sequential,load_model,Model,model_from_json
+from keras.layers import Dense, Dropout, Activation, Flatten
+from keras.layers import Convolution2D,concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
+from keras.layers import Input, merge, UpSampling2D
+from keras.callbacks import ModelCheckpoint
+from keras.optimizers import Adam
+from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
+from keras import backend as K
+K.tensorflow_backend._get_available_gpus()   # return value is discarded; presumably a GPU visibility check - TODO confirm
+import SimpleITK as sitk
+#K.set_image_data_format("channels_first")
+K.set_image_dim_ordering("th")   # 'th' ordering; unet_model() below takes (1, img_size, img_size) inputs and concatenates on axis 1
+img_size = 120      #original img size is 240*240; NOTE(review): 120 is not a multiple of 16, so the four 2x2 poolings in unet_model() reach an odd size (15 -> 7)
+smooth = 1          # constant added to both numerator and denominator in dice_coef()
+num_of_aug = 1      # augmented copies requested per slice in create_data(); also part of the saved model's file name
+num_epoch = 20      # epoch count passed to model.fit(); also part of the saved model/weights file names
+
+
+#%%
+
+import glob
+def create_data(src, mask, label=False, resize=(155,img_size,img_size)):
+    """Load, preprocess and save as .npy every volume matching ``src + mask``.
+
+    label=True binarises (non-zero -> 1), label=False z-score normalises. Slices 50..129 of each
+    volume are augmented and saved together as 'y_<img_size>.npy' / 'x_<img_size>.npy' in the cwd.
+    """
+    files = sorted(glob.glob(src + mask, recursive=True))   # glob order is arbitrary; sorting keeps image/label calls aligned
+    imgs = []
+    print('Processing---', mask)
+    for file in files:
+        img = io.imread(file, plugin='simpleitk')
+        img = trans.resize(img, resize, mode='constant', order=0 if label else 1)   # nearest neighbour for labels: linear blending made border voxels non-zero
+        if label:
+            #img[img == 4] = 1       #turn enhancing tumor into necrosis
+            #img[img != 1] = 0       #only left enhancing tumor + necrosis
+            img[img != 0] = 1       #Region 1 => 1+2+3+4 complete tumor
+            img = img.astype('float32')
+        else:
+            img = (img - img.mean()) / (img.std() or 1.0)      #normalization => zero mean; a constant volume (std 0) becomes all zeros instead of NaN
+        for idx in range(50, 130):      # formerly named ``slice``, which shadowed the builtin
+            imgs.extend(augmentation(img[idx][np.newaxis, np.newaxis], num_of_aug))   # rank-4 (1, 1, H, W) input, one rank-3 item per result row
+    name = 'y_'+ str(img_size) if label else 'x_'+ str(img_size)
+    np.save(name, np.array(imgs).astype('float32'))  # save at home
+    print('Saved', len(files), 'to', name)
+
+#%%
+
+def n4itk(img):         #must input with sitk img object
+    """Return ``img`` cast to float32 and passed through SimpleITK's N4 bias-field correction."""
+    as_float = sitk.Cast(img, sitk.sitkFloat32)
+    brain_mask = sitk.BinaryNot(sitk.BinaryThreshold(as_float, 0, 0))   # inverse of the threshold-[0, 0] mask, meant to span the brain so only it is corrected
+    return sitk.N4BiasFieldCorrection(as_float, brain_mask)
+
+    
+#%%
+
+def augmentation(scans,n):          #input img must be rank 4 
+    """Return ``scans`` followed by ``n`` randomly transformed batches, stacked on axis 0.
+
+    scans: rank-4 array handed to ``ImageDataGenerator.flow``; a copy forms the first rows.
+    n: number of one-sample batches to draw; n <= 0 returns only the copy.
+
+    The generator rotates by up to 25 degrees and flips horizontally/vertically, and the
+    flow is always created with seed=1000.
+    """
+    if n <= 0:      # flow() is endless and the former ``i == n`` exit never fired for n <= 0
+        return scans.copy()
+    datagen = ImageDataGenerator(
+        featurewise_center=False,   
+        samplewise_center=False,  
+        featurewise_std_normalization=False,  
+        samplewise_std_normalization=False,  
+        zca_whitening=False,  
+        rotation_range=25,   
+        #width_shift_range=0.3,  
+        #height_shift_range=0.3,   
+        horizontal_flip=True,   
+        vertical_flip=True,  
+        zoom_range=False)
+    flow = datagen.flow(scans, batch_size=1, seed=1000)
+    # Stack once at the end; vstack inside the loop re-copied the growing array every step.
+    batches = [scans.copy()] + [next(flow) for _ in range(n)]
+    # NOTE (from the original author's reminder): labels need the same augmentation.
+    # create_data() passes label slices through this function in a separate call, relying on
+    # the fixed seed above - TODO confirm both calls really draw identical transforms.
+    return np.vstack(batches)
+#scans_g,labels_g = augmentation(img,img1, 10)
+#X_train = X_train.reshape(X_train.shape[0], 1, img_size, img_size)
+    
+#%%
+
+'''
+Model -
+
+structure:
+
+'''    
+
+def dice_coef(y_true, y_pred):
+    """Smoothed Dice overlap of two tensors: (2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)."""
+    truth, guess = K.flatten(y_true), K.flatten(y_pred)
+    overlap = K.sum(truth * guess)
+    return (2. * overlap + smooth) / (K.sum(truth) + K.sum(guess) + smooth)
+
+
+def dice_coef_loss(y_true, y_pred):
+    return -dice_coef(y_true, y_pred)      # negated Dice coefficient: minimising this loss maximises the overlap
+    
+    
+def _double_conv(tensor, filters):
+    """Apply two 3x3 ReLU convolutions ('same' padding, default stride) with ``filters`` maps."""
+    # (The old 'KERNEL =3 STRIDE =3' remark was wrong: the two 3s are the kernel size.)
+    for _ in range(2):
+        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
+    return tensor
+
+
+def unet_model():
+    """Build and compile the U-Net: 4 pooled encoder stages, a bottleneck, 4 decoder stages.
+
+    Input is a channels-first (1, img_size, img_size) slice; output is a one-channel sigmoid
+    map. The model is compiled with Adam(lr=1e-5), dice_coef_loss and the dice_coef metric.
+
+    2x2 max pooling rounds odd sizes down (15 -> 7), so an upsampled map can be one pixel
+    smaller than its skip tensor (img_size=120 hits this). It is zero-padded back to the
+    skip's size; when img_size is a multiple of 16 no padding layer is added.
+    """
+    from keras.layers import ZeroPadding2D      # local: not among the file-level imports
+
+    inputs = Input((1, img_size, img_size))
+    net, skips = inputs, []
+    for filters in (32, 64, 128, 256):          # contracting path
+        net = _double_conv(net, filters)
+        skips.append(net)                       # kept for the matching decoder stage
+        net = MaxPooling2D(pool_size=(2, 2))(net)
+
+    net = _double_conv(net, 512)                # bottleneck
+
+    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):     # expanding path
+        net = UpSampling2D(size=(2, 2))(net)
+        gap_h = K.int_shape(skip)[2] - K.int_shape(net)[2]     # axes 2/3 are height/width (channels first)
+        gap_w = K.int_shape(skip)[3] - K.int_shape(net)[3]
+        if gap_h or gap_w:
+            net = ZeroPadding2D(((0, gap_h), (0, gap_w)))(net)
+        # concatenate(axis=1) replaces the Keras-1 merge(mode='concat', concat_axis=1)
+        net = _double_conv(concatenate([net, skip], axis=1), filters)
+
+    outputs = Conv2D(1, (1, 1), activation='sigmoid')(net)
+    model = Model(inputs=inputs, outputs=outputs)
+
+    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
+
+    return model
+    
+
+
+    
+#%%
+# catch all Flair .mha volumes and the OT volumes loaded as labels (the patterns match Flair/OT, not T1c)
+create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*Flair*.mha', label=False, resize=(155,img_size,img_size))
+create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*OT*.mha', label=True, resize=(155,img_size,img_size))
+# NOTE: the BRATS2017 cell below saves under the same x_<img_size> / y_<img_size> names, replacing this cell's output.
+#%%
+# catch BRATS2017 Data
+create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_flair.nii.gz', label=False, resize=(155,img_size,img_size))
+create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_GlistrBoost_ManuallyCorrected.nii.gz', label=True, resize=(155,img_size,img_size))
+# NOTE: create_data() saves relative to the working directory, while the loads below read from /home/andy.
+
+#%%
+# load numpy array data
+x = np.load('/home/andy/x_{}.npy'.format(img_size))
+y = np.load('/home/andy/y_{}.npy'.format(img_size))
+
+#%%
+#training - validation_split=0.2 holds out the last 20% of x/y (taken before shuffling)
+num = min(31100, max(len(x) - 100, 0))   # start of the 100-sample preview window; the fixed 31100 gave an empty slice on smaller datasets
+
+model = unet_model()
+history = model.fit(x, y, batch_size=16, validation_split=0.2, epochs=num_epoch, verbose=1, shuffle=True)   # epochs: Keras 2 name of nb_epoch
+pred = model.predict(x[num:num+100])
+
+#%%
+# save model and weights
+model.save('aug{}_{}_epoch{}'.format(num_of_aug,img_size,num_epoch))   # whole model; note the file name carries no extension
+model.save_weights('weights_{}_{}.h5'.format(img_size,num_epoch))      # weights only
+#model.load_weights('weights.h5')
+
+#%%
+# list all data in history
+print(history.history.keys())
+# summarize history for dice_coef (training curve vs. the validation_split curve)
+plt.plot(history.history['dice_coef'])
+plt.plot(history.history['val_dice_coef'])
+plt.title('model dice_coef')
+plt.ylabel('dice_coef')
+plt.xlabel('epoch')
+plt.legend(['train', 'validation'], loc='upper left')
+plt.show()
+# summarize history for loss
+plt.plot(history.history['loss'])
+plt.plot(history.history['val_loss'])
+plt.title('model loss')
+plt.ylabel('loss')
+plt.xlabel('epoch')
+plt.legend(['train', 'validation'], loc='upper left')   # val_loss is the validation curve, not a test set (was labelled 'test')
+plt.show()
+
+#%%
+#show results
+# Two figures, each for one randomly chosen row of ``pred``: the input slice, its label
+# and the prediction side by side. ``pred`` was computed from x[num:num+100], hence the
+# ``+ num`` offset when indexing x and y.
+for n in range(2):
+    i = int(r.random() * pred.shape[0])
+    row = i + num
+    plt.figure(figsize=(15,10))
+    panels = (
+        ('Input'+str(row), x[row, 0, :, :]),
+        ('Ground Truth', y[row, 0, :, :]),
+        ('Prediction', pred[i, 0, :, :]),
+    )
+    for position, (caption, picture) in enumerate(panels, start=131):
+        plt.subplot(position)       # 131, 132, 133: one row, three columns
+        plt.title(caption)
+        plt.imshow(picture,cmap='gray')
+    plt.show()
+
+#%%
+'''
+animation
+'''
+import matplotlib.animation as animation
+def animate(pat, gifname):
+    """Write the frames of ``pat`` (an indexable sequence of images) to ``gifname`` as a GIF."""
+    # Based on @Zombie's code
+    fig = plt.figure()
+    anim = plt.imshow(pat[min(50, len(pat) - 1)])   # initial frame; plain pat[50] raised IndexError for 50 frames or fewer
+    def update(i):
+        anim.set_array(pat[i])      # swap in frame i
+        return anim,                # tuple of redrawn artists, as blit=True requires
+    a = animation.FuncAnimation(fig, update, frames=range(len(pat)), interval=50, blit=True)
+    a.save(gifname, writer='imagemagick')
+    
+#animate(pat, 'test.gif')