Diff of /BRATS2015.py [000000] .. [baebdc]

Switch to unified view

a b/BRATS2015.py
1
#%%
2
3
import numpy as np
4
import pandas as pd
5
import matplotlib.pyplot as plt
6
import skimage.io as io
7
import skimage.transform as trans
8
import random as r
9
from keras.models import Sequential,load_model,Model,model_from_json
10
from keras.layers import Dense, Dropout, Activation, Flatten
11
from keras.layers import Convolution2D,concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
12
from keras.layers import Input, merge, UpSampling2D
13
from keras.callbacks import ModelCheckpoint
14
from keras.optimizers import Adam
15
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
16
from keras import backend as K
17
K.tensorflow_backend._get_available_gpus()
18
import SimpleITK as sitk
19
#K.set_image_data_format("channels_first")
20
K.set_image_dim_ordering("th")
21
img_size = 120      #original img size is 240*240
22
smooth = 1 
23
num_of_aug = 1
24
num_epoch = 20
25
26
27
#%%
28
29
import glob
30
def create_data(src, mask, label=False, resize=(155,img_size,img_size)):
31
    files = glob.glob(src + mask, recursive=True)
32
    imgs = []
33
    print('Processing---', mask)
34
    for file in files:
35
        img = io.imread(file, plugin='simpleitk')
36
        img = trans.resize(img, resize, mode='constant')
37
        if label:
38
            #img[img == 4] = 1       #turn enhancing tumor into necrosis
39
            #img[img != 1] = 0       #only left enhancing tumor + necrosis
40
            img[img != 0] = 1       #Region 1 => 1+2+3+4 complete tumor
41
            img = img.astype('float32')
42
        else:
43
            img = (img-img.mean()) / img.std()      #normalization => zero mean   !!!care for the std=0 problem
44
        for slice in range(50,130):
45
            img_t = img[slice,:,:]
46
            img_t =img_t.reshape((1,)+img_t.shape)
47
            img_t =img_t.reshape((1,)+img_t.shape)   #become rank 4
48
            img_g = augmentation(img_t,num_of_aug)
49
            for n in range(img_g.shape[0]):
50
                imgs.append(img_g[n,:,:,:])
51
    name = 'y_'+ str(img_size) if label else 'x_'+ str(img_size)
52
    np.save(name, np.array(imgs).astype('float32'))  # save at home
53
    print('Saved', len(files), 'to', name)
54
55
#%%
56
57
def n4itk(img):         #must input with sitk img object
58
    img = sitk.Cast(img, sitk.sitkFloat32)
59
    img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0))   ## Create a mask spanning the part containing the brain, as we want to apply the filter to the brain image
60
    corrected_img = sitk.N4BiasFieldCorrection(img, img_mask)
61
    return corrected_img    
62
63
    
64
#%%
65
66
def augmentation(scans,n):          #input img must be rank 4 
67
    datagen = ImageDataGenerator(
68
        featurewise_center=False,   
69
        samplewise_center=False,  
70
        featurewise_std_normalization=False,  
71
        samplewise_std_normalization=False,  
72
        zca_whitening=False,  
73
        rotation_range=25,   
74
        #width_shift_range=0.3,  
75
        #height_shift_range=0.3,   
76
        horizontal_flip=True,   
77
        vertical_flip=True,  
78
        zoom_range=False)
79
    i=0
80
    scans_g=scans.copy()
81
    for batch in datagen.flow(scans, batch_size=1, seed=1000): 
82
        scans_g=np.vstack([scans_g,batch])
83
        i += 1
84
        if i == n:
85
            break
86
    '''    remember arg + labels  
87
    i=0
88
    labels_g=labels.copy()
89
    for batch in datagen.flow(labels, batch_size=1, seed=1000): 
90
        labels_g=np.vstack([labels_g,batch])
91
        i += 1
92
        if i > n:
93
            break    
94
    return ((scans_g,labels_g))'''
95
    return scans_g
96
#scans_g,labels_g = augmentation(img,img1, 10)
97
#X_train = X_train.reshape(X_train.shape[0], 1, img_size, img_size)
98
    
99
#%%
100
101
'''
102
Model -
103
104
structure:
105
106
'''    
107
108
def dice_coef(y_true, y_pred):
109
    y_true_f = K.flatten(y_true)
110
    y_pred_f = K.flatten(y_pred)
111
    intersection = K.sum(y_true_f * y_pred_f)
112
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
113
114
115
def dice_coef_loss(y_true, y_pred):
116
    return -dice_coef(y_true, y_pred)
117
    
118
    
119
def unet_model():
120
    inputs = Input((1, img_size, img_size))
121
    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)      # KERNEL =3 STRIDE =3
122
    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
123
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
124
125
    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
126
    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
127
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
128
129
    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
130
    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
131
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
132
133
    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)
134
    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)
135
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
136
137
    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)
138
    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)
139
140
    up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
141
    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)
142
    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)
143
144
    up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
145
    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)
146
    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)
147
148
    up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
149
    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)
150
    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)
151
152
    up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
153
    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)
154
    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)
155
156
    conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)
157
158
    model = Model(input=inputs, output=conv10)
159
160
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
161
162
    return model
163
    
164
165
166
    
167
#%%
168
# catch all T1c.mha
169
create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*Flair*.mha', label=False, resize=(155,img_size,img_size))
170
create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*OT*.mha', label=True, resize=(155,img_size,img_size))
171
172
#%%
173
# catch BRATS2017 Data
174
create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_flair.nii.gz', label=False, resize=(155,img_size,img_size))
175
create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_GlistrBoost_ManuallyCorrected.nii.gz', label=True, resize=(155,img_size,img_size))
176
177
178
#%%
179
# load numpy array data
180
x = np.load('/home/andy/x_{}.npy'.format(img_size))
181
y = np.load('/home/andy/y_{}.npy'.format(img_size))
182
183
#%%
184
#training
185
num = 31100
186
187
model = unet_model()
188
history = model.fit(x, y, batch_size=16, validation_split=0.2 ,nb_epoch= num_epoch, verbose=1, shuffle=True)
189
pred = model.predict(x[num:num+100])
190
191
#%%
192
# save model and weights
193
model.save('aug{}_{}_epoch{}'.format(num_of_aug,img_size,num_epoch))
194
model.save_weights('weights_{}_{}.h5'.format(img_size,num_epoch))
195
#model.load_weights('weights.h5')
196
197
#%%
198
# list all data in history
199
print(history.history.keys())
200
# summarize history for accuracy
201
plt.plot(history.history['dice_coef'])
202
plt.plot(history.history['val_dice_coef'])
203
plt.title('model dice_coef')
204
plt.ylabel('dice_coef')
205
plt.xlabel('epoch')
206
plt.legend(['train', 'validation'], loc='upper left')
207
plt.show()
208
# summarize history for loss
209
plt.plot(history.history['loss'])
210
plt.plot(history.history['val_loss'])
211
plt.title('model loss')
212
plt.ylabel('loss')
213
plt.xlabel('epoch')
214
plt.legend(['train', 'test'], loc='upper left')
215
plt.show()
216
217
#%%
218
#show results
219
for n in range(2):
220
    i = int(r.random() * pred.shape[0])
221
    plt.figure(figsize=(15,10))
222
223
    plt.subplot(131)
224
    plt.title('Input'+str(i+num))
225
    plt.imshow(x[i+num, 0, :, :],cmap='gray')
226
227
    plt.subplot(132)
228
    plt.title('Ground Truth')
229
    plt.imshow(y[i+num, 0, :, :],cmap='gray')
230
231
    plt.subplot(133)
232
    plt.title('Prediction')
233
    plt.imshow(pred[i, 0, :, :],cmap='gray')
234
235
    plt.show()
236
237
#%%
238
'''
239
animation
240
'''
241
import matplotlib.animation as animation
242
def animate(pat, gifname):
243
    # Based on @Zombie's code
244
    fig = plt.figure()
245
    anim = plt.imshow(pat[50])
246
    def update(i):
247
        anim.set_array(pat[i])
248
        return anim,
249
    
250
    a = animation.FuncAnimation(fig, update, frames=range(len(pat)), interval=50, blit=True)
251
    a.save(gifname, writer='imagemagick')
252
    
253
#animate(pat, 'test.gif')