|
a |
|
b/BRATS2015.py |
|
|
1 |
#%% |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import pandas as pd |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
import skimage.io as io |
|
|
7 |
import skimage.transform as trans |
|
|
8 |
import random as r |
|
|
9 |
from keras.models import Sequential,load_model,Model,model_from_json |
|
|
10 |
from keras.layers import Dense, Dropout, Activation, Flatten |
|
|
11 |
from keras.layers import Convolution2D,concatenate, Conv2D, MaxPooling2D, Conv2DTranspose |
|
|
12 |
from keras.layers import Input, merge, UpSampling2D |
|
|
13 |
from keras.callbacks import ModelCheckpoint |
|
|
14 |
from keras.optimizers import Adam |
|
|
15 |
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img |
|
|
16 |
from keras import backend as K |
|
|
17 |
K.tensorflow_backend._get_available_gpus() |
|
|
18 |
import SimpleITK as sitk |
|
|
19 |
#K.set_image_data_format("channels_first") |
|
|
20 |
K.set_image_dim_ordering("th") |
|
|
21 |
img_size = 120 #original img size is 240*240 |
|
|
22 |
smooth = 1 |
|
|
23 |
num_of_aug = 1 |
|
|
24 |
num_epoch = 20 |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
#%% |
|
|
28 |
|
|
|
29 |
import glob |
|
|
30 |
def create_data(src, mask, label=False, resize=(155,img_size,img_size)): |
|
|
31 |
files = glob.glob(src + mask, recursive=True) |
|
|
32 |
imgs = [] |
|
|
33 |
print('Processing---', mask) |
|
|
34 |
for file in files: |
|
|
35 |
img = io.imread(file, plugin='simpleitk') |
|
|
36 |
img = trans.resize(img, resize, mode='constant') |
|
|
37 |
if label: |
|
|
38 |
#img[img == 4] = 1 #turn enhancing tumor into necrosis |
|
|
39 |
#img[img != 1] = 0 #only left enhancing tumor + necrosis |
|
|
40 |
img[img != 0] = 1 #Region 1 => 1+2+3+4 complete tumor |
|
|
41 |
img = img.astype('float32') |
|
|
42 |
else: |
|
|
43 |
img = (img-img.mean()) / img.std() #normalization => zero mean !!!care for the std=0 problem |
|
|
44 |
for slice in range(50,130): |
|
|
45 |
img_t = img[slice,:,:] |
|
|
46 |
img_t =img_t.reshape((1,)+img_t.shape) |
|
|
47 |
img_t =img_t.reshape((1,)+img_t.shape) #become rank 4 |
|
|
48 |
img_g = augmentation(img_t,num_of_aug) |
|
|
49 |
for n in range(img_g.shape[0]): |
|
|
50 |
imgs.append(img_g[n,:,:,:]) |
|
|
51 |
name = 'y_'+ str(img_size) if label else 'x_'+ str(img_size) |
|
|
52 |
np.save(name, np.array(imgs).astype('float32')) # save at home |
|
|
53 |
print('Saved', len(files), 'to', name) |
|
|
54 |
|
|
|
55 |
#%% |
|
|
56 |
|
|
|
57 |
def n4itk(img): #must input with sitk img object |
|
|
58 |
img = sitk.Cast(img, sitk.sitkFloat32) |
|
|
59 |
img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0)) ## Create a mask spanning the part containing the brain, as we want to apply the filter to the brain image |
|
|
60 |
corrected_img = sitk.N4BiasFieldCorrection(img, img_mask) |
|
|
61 |
return corrected_img |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
#%% |
|
|
65 |
|
|
|
66 |
def augmentation(scans,n): #input img must be rank 4 |
|
|
67 |
datagen = ImageDataGenerator( |
|
|
68 |
featurewise_center=False, |
|
|
69 |
samplewise_center=False, |
|
|
70 |
featurewise_std_normalization=False, |
|
|
71 |
samplewise_std_normalization=False, |
|
|
72 |
zca_whitening=False, |
|
|
73 |
rotation_range=25, |
|
|
74 |
#width_shift_range=0.3, |
|
|
75 |
#height_shift_range=0.3, |
|
|
76 |
horizontal_flip=True, |
|
|
77 |
vertical_flip=True, |
|
|
78 |
zoom_range=False) |
|
|
79 |
i=0 |
|
|
80 |
scans_g=scans.copy() |
|
|
81 |
for batch in datagen.flow(scans, batch_size=1, seed=1000): |
|
|
82 |
scans_g=np.vstack([scans_g,batch]) |
|
|
83 |
i += 1 |
|
|
84 |
if i == n: |
|
|
85 |
break |
|
|
86 |
''' remember arg + labels |
|
|
87 |
i=0 |
|
|
88 |
labels_g=labels.copy() |
|
|
89 |
for batch in datagen.flow(labels, batch_size=1, seed=1000): |
|
|
90 |
labels_g=np.vstack([labels_g,batch]) |
|
|
91 |
i += 1 |
|
|
92 |
if i > n: |
|
|
93 |
break |
|
|
94 |
return ((scans_g,labels_g))''' |
|
|
95 |
return scans_g |
|
|
96 |
#scans_g,labels_g = augmentation(img,img1, 10) |
|
|
97 |
#X_train = X_train.reshape(X_train.shape[0], 1, img_size, img_size) |
|
|
98 |
|
|
|
99 |
#%% |
|
|
100 |
|
|
|
101 |
''' |
|
|
102 |
Model - |
|
|
103 |
|
|
|
104 |
structure: |
|
|
105 |
|
|
|
106 |
''' |
|
|
107 |
|
|
|
108 |
def dice_coef(y_true, y_pred): |
|
|
109 |
y_true_f = K.flatten(y_true) |
|
|
110 |
y_pred_f = K.flatten(y_pred) |
|
|
111 |
intersection = K.sum(y_true_f * y_pred_f) |
|
|
112 |
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) |
|
|
113 |
|
|
|
114 |
|
|
|
115 |
def dice_coef_loss(y_true, y_pred): |
|
|
116 |
return -dice_coef(y_true, y_pred) |
|
|
117 |
|
|
|
118 |
|
|
|
119 |
def unet_model(): |
|
|
120 |
inputs = Input((1, img_size, img_size)) |
|
|
121 |
conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs) # KERNEL =3 STRIDE =3 |
|
|
122 |
conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1) |
|
|
123 |
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) |
|
|
124 |
|
|
|
125 |
conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1) |
|
|
126 |
conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2) |
|
|
127 |
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) |
|
|
128 |
|
|
|
129 |
conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2) |
|
|
130 |
conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3) |
|
|
131 |
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) |
|
|
132 |
|
|
|
133 |
conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3) |
|
|
134 |
conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4) |
|
|
135 |
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) |
|
|
136 |
|
|
|
137 |
conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4) |
|
|
138 |
conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5) |
|
|
139 |
|
|
|
140 |
up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1) |
|
|
141 |
conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6) |
|
|
142 |
conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6) |
|
|
143 |
|
|
|
144 |
up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1) |
|
|
145 |
conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7) |
|
|
146 |
conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7) |
|
|
147 |
|
|
|
148 |
up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1) |
|
|
149 |
conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8) |
|
|
150 |
conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8) |
|
|
151 |
|
|
|
152 |
up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1) |
|
|
153 |
conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9) |
|
|
154 |
conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9) |
|
|
155 |
|
|
|
156 |
conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9) |
|
|
157 |
|
|
|
158 |
model = Model(input=inputs, output=conv10) |
|
|
159 |
|
|
|
160 |
model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef]) |
|
|
161 |
|
|
|
162 |
return model |
|
|
163 |
|
|
|
164 |
|
|
|
165 |
|
|
|
166 |
|
|
|
167 |
#%% |
|
|
168 |
# catch all T1c.mha |
|
|
169 |
create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*Flair*.mha', label=False, resize=(155,img_size,img_size)) |
|
|
170 |
create_data('/home/andy/Brain_tumor/BRATS2015/BRATS2015_Training/HGG/', '**/*OT*.mha', label=True, resize=(155,img_size,img_size)) |
|
|
171 |
|
|
|
172 |
#%% |
|
|
173 |
# catch BRATS2017 Data |
|
|
174 |
create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_flair.nii.gz', label=False, resize=(155,img_size,img_size)) |
|
|
175 |
create_data('/home/andy/Brain_tumor/BRATS2017/Pre-operative_TCGA_GBM_NIfTI_and_Segmentations/', '**/*_GlistrBoost_ManuallyCorrected.nii.gz', label=True, resize=(155,img_size,img_size)) |
|
|
176 |
|
|
|
177 |
|
|
|
178 |
#%% |
|
|
179 |
# load numpy array data |
|
|
180 |
x = np.load('/home/andy/x_{}.npy'.format(img_size)) |
|
|
181 |
y = np.load('/home/andy/y_{}.npy'.format(img_size)) |
|
|
182 |
|
|
|
183 |
#%% |
|
|
184 |
#training |
|
|
185 |
num = 31100 |
|
|
186 |
|
|
|
187 |
model = unet_model() |
|
|
188 |
history = model.fit(x, y, batch_size=16, validation_split=0.2 ,nb_epoch= num_epoch, verbose=1, shuffle=True) |
|
|
189 |
pred = model.predict(x[num:num+100]) |
|
|
190 |
|
|
|
191 |
#%% |
|
|
192 |
# save model and weights |
|
|
193 |
model.save('aug{}_{}_epoch{}'.format(num_of_aug,img_size,num_epoch)) |
|
|
194 |
model.save_weights('weights_{}_{}.h5'.format(img_size,num_epoch)) |
|
|
195 |
#model.load_weights('weights.h5') |
|
|
196 |
|
|
|
197 |
#%% |
|
|
198 |
# list all data in history |
|
|
199 |
print(history.history.keys()) |
|
|
200 |
# summarize history for accuracy |
|
|
201 |
plt.plot(history.history['dice_coef']) |
|
|
202 |
plt.plot(history.history['val_dice_coef']) |
|
|
203 |
plt.title('model dice_coef') |
|
|
204 |
plt.ylabel('dice_coef') |
|
|
205 |
plt.xlabel('epoch') |
|
|
206 |
plt.legend(['train', 'validation'], loc='upper left') |
|
|
207 |
plt.show() |
|
|
208 |
# summarize history for loss |
|
|
209 |
plt.plot(history.history['loss']) |
|
|
210 |
plt.plot(history.history['val_loss']) |
|
|
211 |
plt.title('model loss') |
|
|
212 |
plt.ylabel('loss') |
|
|
213 |
plt.xlabel('epoch') |
|
|
214 |
plt.legend(['train', 'test'], loc='upper left') |
|
|
215 |
plt.show() |
|
|
216 |
|
|
|
217 |
#%% |
|
|
218 |
#show results |
|
|
219 |
for n in range(2): |
|
|
220 |
i = int(r.random() * pred.shape[0]) |
|
|
221 |
plt.figure(figsize=(15,10)) |
|
|
222 |
|
|
|
223 |
plt.subplot(131) |
|
|
224 |
plt.title('Input'+str(i+num)) |
|
|
225 |
plt.imshow(x[i+num, 0, :, :],cmap='gray') |
|
|
226 |
|
|
|
227 |
plt.subplot(132) |
|
|
228 |
plt.title('Ground Truth') |
|
|
229 |
plt.imshow(y[i+num, 0, :, :],cmap='gray') |
|
|
230 |
|
|
|
231 |
plt.subplot(133) |
|
|
232 |
plt.title('Prediction') |
|
|
233 |
plt.imshow(pred[i, 0, :, :],cmap='gray') |
|
|
234 |
|
|
|
235 |
plt.show() |
|
|
236 |
|
|
|
237 |
#%% |
|
|
238 |
''' |
|
|
239 |
animation |
|
|
240 |
''' |
|
|
241 |
import matplotlib.animation as animation |
|
|
242 |
def animate(pat, gifname): |
|
|
243 |
# Based on @Zombie's code |
|
|
244 |
fig = plt.figure() |
|
|
245 |
anim = plt.imshow(pat[50]) |
|
|
246 |
def update(i): |
|
|
247 |
anim.set_array(pat[i]) |
|
|
248 |
return anim, |
|
|
249 |
|
|
|
250 |
a = animation.FuncAnimation(fig, update, frames=range(len(pat)), interval=50, blit=True) |
|
|
251 |
a.save(gifname, writer='imagemagick') |
|
|
252 |
|
|
|
253 |
#animate(pat, 'test.gif') |