Diff of /demo.py [000000] .. [e44b03]

a/demo.py b/demo.py
import argparse
import os

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import Conv3D, concatenate
from nibabel import load as load_nii
from sklearn.preprocessing import scale

import tf_models

# Legacy patch-size constants, kept for reference; the script now reads these
# values from the command line (see parse_inputs below).
# SAVE_PATH = 'unet3d_baseline.hdf5'
# OFFSET_W = 16
# OFFSET_H = 16
# OFFSET_C = 4
# HSIZE = 64
# WSIZE = 64
# CSIZE = 16
# batches_h, batches_w, batches_c = (224-HSIZE)/OFFSET_H+1, (224-WSIZE)/OFFSET_W+1, (152-CSIZE)/OFFSET_C+1

def parse_inputs():
    parser = argparse.ArgumentParser(description='Test different nets with 3D data.')
    parser.add_argument('-r', '--root-path', dest='root_path',
                        default='/mnt/disk1/dat/lchen63/brain/data/data2')
    parser.add_argument('-m', '--model-path', dest='model_path', default='NoneDense-0')
    parser.add_argument('-ow', '--offset-width', dest='offset_w', type=int, default=12)
    parser.add_argument('-oh', '--offset-height', dest='offset_h', type=int, default=12)
    parser.add_argument('-oc', '--offset-channel', dest='offset_c', type=int, default=12)
    parser.add_argument('-ws', '--width-size', dest='wsize', type=int, default=38)
    parser.add_argument('-hs', '--height-size', dest='hsize', type=int, default=38)
    parser.add_argument('-cs', '--channel-size', dest='csize', type=int, default=38)
    parser.add_argument('-ps', '--pred-size', dest='psize', type=int, default=12)
    parser.add_argument('-gpu', '--gpu', dest='gpu', type=str, default='0')
    parser.add_argument('-mn', '--model_name', dest='model_name', type=str, default='dense24')
    # Note: argparse's type=bool turns any non-empty string (even 'False') into
    # True, so -nc is effectively always on unless passed an empty string.
    parser.add_argument('-nc', '--correction', dest='correction', type=bool, default=True)
    parser.add_argument('-sp', '--save_path', dest='save_path', type=str,
                        default='/mnt/disk1/dat/lchen63/brain/data/result/')

    return vars(parser.parse_args())


options = parse_inputs()
os.environ["CUDA_VISIBLE_DEVICES"] = options['gpu']
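
# A typical invocation would look like this (paths follow the defaults above
# and are site-specific examples, not guaranteed to exist on your machine):
#
#   python demo.py -r /mnt/disk1/dat/lchen63/brain/data/data2 \
#                  -m NoneDense-0 -mn dense24 -gpu 0 \
#                  -sp /mnt/disk1/dat/lchen63/brain/data/result/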


def vox_preprocess(vox):
    # Standardize voxel intensities per channel (zero mean, unit variance)
    # and restore the original spatial shape.
    vox_shape = vox.shape
    vox = np.reshape(vox, (-1, vox_shape[-1]))
    vox = scale(vox, axis=0)
    return np.reshape(vox, vox_shape)


def one_hot(y, num_classes):
    # Turn a 1-D array of integer labels into a (len(y), num_classes) one-hot matrix.
    y_ = np.zeros([len(y), num_classes])
    y_[np.arange(len(y)), y] = 1
    return y_


def dice_coef_np(y_true, y_pred, num_classes):
    """Compute the per-class Dice coefficient.

    :param y_true: sparse (integer) ground-truth labels
    :param y_pred: sparse (integer) predicted labels
    :param num_classes: number of classes
    :return: array of Dice scores, one per class
    """
    y_true = y_true.astype(int).flatten()
    y_true = one_hot(y_true, num_classes)
    y_pred = y_pred.astype(int).flatten()
    y_pred = one_hot(y_pred, num_classes)
    intersection = np.sum(y_true * y_pred, axis=0)
    return (2. * intersection) / (np.sum(y_true, axis=0) + np.sum(y_pred, axis=0))
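
# Toy sanity check (hypothetical labels, not part of the pipeline):
#   dice_coef_np(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0]), 2)
# gives roughly array([0.8, 0.6667]): background Dice 2*2/(2+3),
# class-1 Dice 2*1/(2+1).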


def norm(image):
    # Normalize with statistics of the non-zero (brain) voxels only, so the
    # zero background does not skew the mean and standard deviation.
    image = np.squeeze(image)
    image_nonzero = image[np.nonzero(image)]
    return (image - image_nonzero.mean()) / image_nonzero.std()


def vox_generator_test(all_files):
    # Endlessly yields (raw volume, normalized volume, labels) for each test case.
    path = options['root_path']

    while 1:
        for p in all_files:
            if options['correction']:
                flair = load_nii(os.path.join(path, p, p + '_flair_corrected.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, p, p + '_t2_corrected.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, p, p + '_t1_corrected.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, p, p + '_t1ce_corrected.nii.gz')).get_data()
            else:
                flair = load_nii(os.path.join(path, p, p + '_flair.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, p, p + '_t2.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, p, p + '_t1.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, p, p + '_t1ce.nii.gz')).get_data()

            # Stack the four modalities along the last axis: (H, W, C, 4).
            data = np.array([flair, t2, t1, t1ce])
            data = np.transpose(data, axes=[1, 2, 3, 0])

            data_norm = np.array([norm(flair), norm(t2), norm(t1), norm(t1ce)])
            data_norm = np.transpose(data_norm, axes=[1, 2, 3, 0])

            labels = load_nii(os.path.join(path, p, p + '_seg.nii.gz')).get_data()

            yield data, data_norm, labels
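
# Expected layout under root_path, inferred from the paths above:
#   <root>/<case>/<case>_flair.nii.gz, <case>_t2.nii.gz, <case>_t1.nii.gz,
#   <case>_t1ce.nii.gz, <case>_seg.nii.gz (plus *_corrected variants when
#   the -nc/--correction flag is set).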

def main():
    test_files = []
    with open('test.txt') as f:
        for line in f:
            test_files.append(line.strip())

    num_labels = 5
    OFFSET_H = options['offset_h']
    OFFSET_W = options['offset_w']
    OFFSET_C = options['offset_c']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    PSIZE = options['psize']
    SAVE_PATH = options['model_path']
    model_name = options['model_name']

    # The network predicts a PSIZE^3 core from each HSIZE x WSIZE x CSIZE input
    # patch; these offsets center that core inside the input patch.
    OFFSET_PH = (HSIZE - PSIZE) // 2
    OFFSET_PW = (WSIZE - PSIZE) // 2
    OFFSET_PC = (CSIZE - PSIZE) // 2

    # Number of sliding-window positions needed to cover a 240 x 240 x 155
    # BraTS volume at the given strides.
    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1
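    # With the defaults (38-voxel patches, stride 12) this works out to
    #   batches_w = batches_h = ceil((240 - 38) / 12) + 1 = 18
    #   batches_c = ceil((155 - 38) / 12) + 1 = 11
    # i.e. 18 * 18 * 11 = 3564 patches per volume.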

    # Two-channel inputs: (flair, t2) feeds one pathway, (t1, t1ce) the other.
    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')
    elif model_name in ('dense24', 'dense24_nocorrection'):
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    # Fuse the two pathways at both scales, then map features to per-class scores.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum the two scales over the centered PSIZE^3 core; 13:25 equals
    # OFFSET_PH:OFFSET_PH + PSIZE for the default patch size 38 and core size 12.
    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []
    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)
        for i in range(len(test_files)):
            print('predicting %s' % test_files[i])
            x, x_n, y = next(data_gen_test)
            pred = np.zeros([240, 240, 155, 5])
            for hi in range(batches_h):
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = offset_h + OFFSET_PH
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = offset_w + OFFSET_PW
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = offset_c + OFFSET_PC
                        data = x[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
                        data_norm = x_n[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
                        data_norm = np.expand_dims(data_norm, 0)
                        # Skip patches that are entirely background (all zeros).
                        if not (np.max(data) == 0 and np.min(data) == 0):
                            score = sess.run(fetches=t1_t1ce_score,
                                             feed_dict={flair_t2_node: data_norm[:, :, :, :, :2],
                                                        t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                                        K.learning_phase(): 0})
                            pred[offset_ph:offset_ph + PSIZE,
                                 offset_pw:offset_pw + PSIZE,
                                 offset_pc:offset_pc + PSIZE, :] += np.squeeze(score)

            # Collapse accumulated class scores to a hard segmentation.
            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)
            print('calculating dice...')
            print(options['save_path'] + test_files[i] + '_prediction')
            np.save(options['save_path'] + test_files[i] + '_prediction', pred)

            # BraTS regions: whole tumor = any tumor label, tumor core = labels
            # 1 and 4, enhancing tumor = label 4.
            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)
            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print('mean dice whole:')
        print(np.mean(dice_whole, axis=0))
        print('mean dice core:')
        print(np.mean(dice_core, axis=0))
        print('mean dice enhance:')
        print(np.mean(dice_et, axis=0))

        np.save(model_name + '_dice_whole', dice_whole)
        np.save(model_name + '_dice_core', dice_core)
        np.save(model_name + '_dice_enhance', dice_et)
        print('pred saved')


if __name__ == '__main__':
    main()