Diff of /ConvNet_driver.py [000000] .. [271336]

Switch to unified view

a b/ConvNet_driver.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Wed Nov 02 21:35:59 2016
4
5
@author: seeker105
6
"""
7
import os.path
8
import sys
9
from ConvNet import LeNet
10
import json
11
import SimpleITK as sitk
12
import pylab
13
from skimage import color
14
from sklearn.utils import shuffle
15
from scipy.ndimage.interpolation import rotate
16
from keras.optimizers import SGD
17
from keras.utils import np_utils
18
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
19
import numpy as np
20
from Brain_pipeline import Pipeline
21
import Brain_pipeline
22
import Metrics
23
from glob import glob
24
import model_test
25
26
''' Script to drive loading, training, testing and saving the brain MRI
27
    First, we load all the images, and process them through the Pipeline,
28
    and get the pre-processed images as output.
29
    Then, we train the model from ConvNet or load the weights into it.
30
    We divide the training and test data using train_test_split.
31
    '''
32
def show_segmented_image(orig_img, pred_img):
33
    '''
34
    Show the prediction over the original image
35
    INPUT:
36
        1)orig_img: the test image, which was used as input
37
        2)pred_img: the prediction output
38
    OUTPUT:
39
        segmented image rendering
40
    '''
41
    #define the colours of the labels
42
    red = [10, 0, 0] #label 1
43
    yellow = [10, 10, 0] #label 2
44
    green = [0, 10, 0]  #label 3
45
    blue = [0, 0, 10] #label 4
46
    #convert original image to rgb
47
    gray_im = color.gray2rgb(orig_img)
48
    #color the tumor voxels
49
    gray_im[pred_img == 1] = red
50
    gray_im[pred_img == 2] = yellow
51
    gray_im[pred_img == 3] = green
52
    gray_im[pred_img == 4] = blue
53
    pylab.imshow(gray_im)
54
55
56
def step_decay(epochs):
57
    init_rate = 0.003
58
    fin_rate = 0.00003
59
    total_epochs = 24
60
    print 'ep: {}'.format(epochs)
61
    if epochs<25:
62
        lrate = init_rate - (init_rate - fin_rate)/total_epochs * float(epochs)
63
    else: lrate = 0.00003
64
    print 'lrate: {}'.format(model.optimizer.lr.get_value())
65
    return lrate
66
67
pth_train = 'D:/New folder/BRATS2015_Training/train_slices/'
68
pth_test = 'D:/New folder/BRATS2015_Training/test_slices/'
69
x = Pipeline(pth_train, pth_test)   #pass the images through the preprocessing steps
70
71
#build the model
72
model = LeNet.build_Pereira(33, 33, 4, 5)   
73
74
#callback
75
change_lr = LearningRateScheduler(step_decay)
76
77
#initialize the optimizer and model
78
opt = SGD(lr = 0.003, momentum=0.9, decay= 0, nesterov = True)
79
model.compile(loss = 'categorical_crossentropy', optimizer=opt, metrics = ['accuracy'])
80
81
82
#load training patches
83
X_patches, Y_labels, mu, sigma = x.training_patches([180000, 67500, 67500, 67500, 67500])
84
tmp = rotate(X_patches, 90, (2, 3))
85
tmp = np.append(tmp, rotate(X_patches, -90, (2, 3)), axis=0)
86
tmp = np.append(tmp, rotate(X_patches, 180, (2, 3)), axis=0)
87
X_patches = np.append(X_patches, tmp, axis=0)
88
Y_labels = np.hstack(Y_labels)
89
for i in xrange(2):
90
    Y_labels = np.append(Y_labels, Y_labels, axis=0)
91
# Labels should be in categorical array form 1x5
92
Y_labels = np_utils.to_categorical(Y_labels, 5)
93
X_patches, Y_labels = shuffle(X_patches, Y_labels, random_state=0)
94
95
#save model after each epoch
96
os.mkdir(r'D:\New folder\Pereira_model_checkpoints')
97
checkpointer = ModelCheckpoint(filepath='D:/New folder/Pereira_model_checkpoints/weights.{epoch:02d}-{val_loss:.2f}.keras2.hdf5',monitor = 'val_loss', verbose=1)
98
#fit model and shuffle training data
99
hist = model.fit(X_patches[:200000], Y_labels[:200000], nb_epoch=25, batch_size=128, verbose=1, validation_split=0.1, callbacks = [change_lr, checkpointer])
100
 
101
#save model
102
sv_pth = 'D:/New Folder/Pereira_model_checkpoints/model_weights'
103
m = '{}.json'.format(sv_pth)
104
w = '{}.hdf5'.format(sv_pth)
105
model.save_weights(w)
106
json_strng = model.to_json()
107
with open(m, 'w') as f:
108
    json.dump(json_strng, f)
109
    
110
111
#test all the test image slices
112
test_im = x.test_im.swapaxes(0,1)
113
gt = test_im[4]
114
test_im = test_im[:4].swapaxes(0, 1)
115
predicted_images, params = model_test.test_slices(test_im[158:159], gt[158:159], model, mu, sigma)
116
117
'''test_pths = zip(*x.pathnames_test)
118
#show a segmented slice
119
tst = test_pths[0]#random.choice(test_pths)
120
test_arr = [sitk.GetArrayFromImage(sitk.ReadImage(i)) for i in tst]
121
final_pth = os.path.dirname(os.path.dirname(tst[0])) +  '/' + os.path.splitext(os.path.splitext(os.path.basename(tst[0]))[0])[0] + '_processed_predicted_70.mha'  
122
slice_arr = [test_arr[j][70] for j in xrange(4)]
123
patches = Brain_pipeline.test_patches(slice_arr)
124
pred = model.predict_classes(patches)
125
pred = Brain_pipeline.reconstruct_labels(pred)
126
show_segmented_image(test_arr[0][70], pred)
127
sitk.WriteImage(sitk.GetImageFromArray(np.array(pred.astype(float))), final_pth)
128
129
130
#evaluate metrics
131
DSC_arr = [] #stores DSC
132
DSC_core_arr = [] #stores list of core DSCs
133
PPV_arr = []
134
acc_arr = []
135
136
#use for getting orignal brain image and prediction label slices
137
# use for:
138
    #overlay images
139
    #segmentation vs orig label. it's in test_paths
140
    #with/without nyul
141
    #4 sequences after nyul. for original ones, redefine paths
142
    #ok. now we gotta see metrics brother
143
pred_pth = []
144
t1c_pth = []
145
pred_arr = []
146
147
for i in xrange(len(test_pths)):
148
    tst = test_pths[i]
149
    test_arr = [sitk.GetArrayFromImage(sitk.ReadImage(j)) for j in tst]
150
    #take slices
151
    slice_arr = [test_arr[j][70] for j in xrange(4)]
152
    #read original slice label
153
    orig = test_arr[4][70]
154
    patches = Brain_pipeline.test_patches(slice_arr)
155
    pred = model.predict_classes(patches)
156
    pred = Brain_pipeline.reconstruct_labels(pred)
157
    acc_arr.append(Metrics.accuracy(pred, orig))
158
    DSC_arr.append(Metrics.DSC(pred, orig, 2))
159
    DSC_core_arr.append(Metrics.DSC_core_tumor(pred, orig))
160
    PPV_arr.append(Metrics.PPV(pred, orig))
161
    print 'acc: {}'.format(acc_arr[i])
162
    print 'DSC: {}'.format(DSC_arr[i])
163
    print 'DSC_core: {}'.format(DSC_core_arr[i])
164
    print 'PPV : {}'.format(PPV_arr[i])
165
    sys.stdout.flush()
166
    final_pth = os.path.dirname(tst[4]) +  '/' + os.path.splitext(os.path.basename(tst[0]))[0] + '_predicted_70.mha'  
167
    pred_pth.append(final_pth)
168
    pred_arr.append(pred)
169
    t1c_pth.append([flp for flp in glob(os.path.dirname(tst[2]) + '/*.mha') if 'n4' not in flp])
170
               '''