Diff of /test.py [000000] .. [e44b03]

Switch to unified view

a b/test.py
1
import numpy as np
2
import keras
3
import argparse
4
import os
5
import tf_models
6
import tensorflow as tf
7
from keras.models import Sequential, Model
8
from keras.layers import Dense, Conv3D, Dropout, Flatten, Input, concatenate, Reshape, Lambda, Permute
9
from keras.layers.core import Dense, Dropout, Activation, Reshape
10
from keras.layers.convolutional import Conv3D, Conv3DTranspose, UpSampling3D
11
from keras.layers.pooling import AveragePooling3D
12
from keras.layers import Input
13
from keras.layers.merge import concatenate
14
from keras.layers.normalization import BatchNormalization
15
from tensorflow.contrib.keras.python.keras.backend import learning_phase
16
17
from nibabel import load as load_nii
18
from sklearn.preprocessing import scale
19
import matplotlib.pyplot as plt
20
21
22
23
# SAVE_PATH = 'unet3d_baseline.hdf5'
24
# OFFSET_W = 16
25
# OFFSET_H = 16
26
# OFFSET_C = 4
27
# HSIZE = 64
28
# WSIZE = 64
29
# CSIZE = 16
30
# batches_h, batches_w, batches_c = (224-HSIZE)/OFFSET_H+1, (224-WSIZE)/OFFSET_W+1, (152 - CSIZE)/OFFSET_C+1
31
32
33
def parse_inputs():
34
    parser = argparse.ArgumentParser(description='Test different nets with 3D data.')
35
    parser.add_argument('-r', '--root-path', dest='root_path', default='/mnt/disk1/dat/lchen63/brain/data/data2')
36
    parser.add_argument('-m', '--model-path', dest='model_path',
37
                        default='NoneDense-0')
38
    parser.add_argument('-ow', '--offset-width', dest='offset_w', type=int, default=12)
39
    parser.add_argument('-oh', '--offset-height', dest='offset_h', type=int, default=12)
40
    parser.add_argument('-oc', '--offset-channel', dest='offset_c', nargs='+', type=int, default=12)
41
    parser.add_argument('-ws', '--width-size', dest='wsize', type=int, default=38)
42
    parser.add_argument('-hs', '--height-size', dest='hsize', type=int, default=38)
43
    parser.add_argument('-cs', '--channel-size', dest='csize', type=int, default=38)
44
    parser.add_argument('-ps', '--pred-size', dest='psize', type=int, default=12)
45
    parser.add_argument('-gpu', '--gpu', dest='gpu', type=str, default='0')
46
    parser.add_argument('-mn', '--model_name', dest='model_name', type=str, default='dense24')
47
    parser.add_argument('-nc', '--correction', dest='correction', type=bool, default=True)
48
49
50
    return vars(parser.parse_args())
51
52
53
options = parse_inputs()
54
os.environ["CUDA_VISIBLE_DEVICES"] = options['gpu']
55
56
57
def segmentation_loss(y_true, y_pred, n_classes):
58
    y_true = tf.reshape(y_true, (-1, n_classes))
59
    y_pred = tf.reshape(y_pred, (-1, n_classes))
60
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
61
                                                                  logits=y_pred))
62
63
64
def vox_preprocess(vox):
65
    vox_shape = vox.shape
66
    vox = np.reshape(vox, (-1, vox_shape[-1]))
67
    vox = scale(vox, axis=0)
68
    return np.reshape(vox, vox_shape)
69
70
71
def one_hot(y, num_classees):
72
    y_ = np.zeros([len(y), num_classees])
73
    y_[np.arange(len(y)), y] = 1
74
    return y_
75
76
77
def dice_coef_np(y_true, y_pred, num_classes):
78
    """
79
80
    :param y_true: sparse labels
81
    :param y_pred: sparse labels
82
    :param num_classes: number of classes
83
    :return:
84
    """
85
    y_true = y_true.astype(int)
86
    y_pred = y_pred.astype(int)
87
    y_true = y_true.flatten()
88
    y_true = one_hot(y_true, num_classes)
89
    y_pred = y_pred.flatten()
90
    y_pred = one_hot(y_pred, num_classes)
91
    intersection = np.sum(y_true * y_pred, axis=0)
92
    return (2. * intersection) / (np.sum(y_true, axis=0) + np.sum(y_pred, axis=0))
93
94
95
def DenseNetUnit3D(x, growth_rate, ksize, n, bn_decay=0.99):
96
    for i in range(n):
97
        concat = x
98
        x = BatchNormalization(center=True, scale=True, momentum=bn_decay)(x)
99
        x = Activation('relu')(x)
100
        x = Conv3D(filters=growth_rate, kernel_size=ksize, padding='same', kernel_initializer='he_uniform',
101
                   use_bias=False)(x)
102
        x = concatenate([concat, x])
103
    return x
104
105
106
def DenseNetTransit(x, rate=1, name=None):
107
    if rate != 1:
108
        out_features = x.get_shape().as_list()[-1] * rate
109
        x = BatchNormalization(center=True, scale=True, name=name + '_bn')(x)
110
        x = Activation('relu', name=name + '_relu')(x)
111
        x = Conv3D(filters=out_features, kernel_size=1, strides=1, padding='same', kernel_initializer='he_normal',
112
                   use_bias=False, name=name + '_conv')(x)
113
    x = AveragePooling3D(pool_size=2, strides=2, padding='same')(x)
114
    return x
115
116
117
def dense_net(input):
118
    x = Conv3D(filters=24, kernel_size=3, strides=1, kernel_initializer='he_uniform', padding='same', use_bias=False)(
119
        input)
120
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
121
    x = DenseNetTransit(x)
122
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
123
    x = DenseNetTransit(x)
124
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
125
    x = BatchNormalization()(x)
126
    x = Activation('relu')(x)
127
    return x
128
129
130
def dense_model(patch_size, num_classes):
131
    merged_inputs = Input(shape=patch_size + (4,), name='merged_inputs')
132
    flair = Reshape(patch_size + (1,))(
133
        Lambda(
134
            lambda l: l[:, :, :, :, 0],
135
            output_shape=patch_size + (1,))(merged_inputs),
136
    )
137
    t2 = Reshape(patch_size + (1,))(
138
        Lambda(lambda l: l[:, :, :, :, 1], output_shape=patch_size + (1,))(merged_inputs)
139
    )
140
    t1 = Lambda(lambda l: l[:, :, :, :, 2:], output_shape=patch_size + (2,))(merged_inputs)
141
142
    flair = dense_net(flair)
143
    t2 = dense_net(t2)
144
    t1 = dense_net(t1)
145
146
    t2 = concatenate([flair, t2])
147
148
    t1 = concatenate([t2, t1])
149
150
    tumor = Conv3D(2, kernel_size=1, strides=1, name='tumor')(flair)
151
    core = Conv3D(3, kernel_size=1, strides=1, name='core')(t2)
152
    enhancing = Conv3D(num_classes, kernel_size=1, strides=1, name='enhancing')(t1)
153
    net = Model(inputs=merged_inputs, outputs=[tumor, core, enhancing])
154
155
    return net
156
157
158
def norm(image):
159
    image = np.squeeze(image)
160
    image_nonzero = image[np.nonzero(image)]
161
    return (image - image_nonzero.mean()) / image_nonzero.std()
162
163
164
def vox_generator_test(all_files):
165
166
    path = options['root_path']
167
168
    while 1:
169
        for file in all_files:
170
            p = file
171
            if options['correction']:
172
                flair = load_nii(os.path.join(path, file, file + '_flair_corrected.nii.gz')).get_data()
173
                t2 = load_nii(os.path.join(path, file, file + '_t2_corrected.nii.gz')).get_data()
174
                t1 = load_nii(os.path.join(path, file, file + '_t1_corrected.nii.gz')).get_data()
175
                t1ce = load_nii(os.path.join(path, file, file + '_t1ce_corrected.nii.gz')).get_data()
176
            else:
177
                flair = load_nii(os.path.join(path, p, p + '_flair.nii.gz')).get_data()
178
179
                t2 = load_nii(os.path.join(path, p, p + '_t2.nii.gz')).get_data()
180
181
                t1 = load_nii(os.path.join(path, p, p + '_t1.nii.gz')).get_data()
182
183
                t1ce = load_nii(os.path.join(path, p, p + '_t1ce.nii.gz')).get_data()
184
            data = np.array([flair, t2, t1, t1ce])
185
            data = np.transpose(data, axes=[1, 2, 3, 0])
186
187
            data_norm = np.array([norm(flair), norm(t2), norm(t1), norm(t1ce)])
188
            data_norm = np.transpose(data_norm, axes=[1, 2, 3, 0])
189
190
            labels = load_nii(os.path.join(path, p, p + '_seg.nii.gz')).get_data()
191
192
            yield data, data_norm, labels
193
194
195
196
def main():
197
    test_files = []
198
    with open('test.txt') as f:
199
        for line in f:
200
            test_files.append(line[:-1])
201
202
    num_labels = 5
203
    OFFSET_H = options['offset_h']
204
    OFFSET_W = options['offset_w']
205
    OFFSET_C = options['offset_c']
206
    HSIZE = options['hsize']
207
    WSIZE = options['wsize']
208
    CSIZE = options['csize']
209
    PSIZE = options['psize']
210
    SAVE_PATH = options['model_path']
211
    model_name = options['model_name']
212
213
    OFFSET_PH = (HSIZE - PSIZE) / 2
214
    OFFSET_PW = (WSIZE - PSIZE) / 2
215
    OFFSET_PC = (CSIZE - PSIZE) / 2
216
217
    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
218
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
219
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1
220
221
222
223
    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
224
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
225
226
227
    if model_name == 'dense48':
228
        
229
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
230
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
231
    elif model_name == 'no_dense':
232
233
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
234
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')
235
236
    elif model_name == 'dense24':
237
238
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
239
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
240
241
    elif model_name == 'dense24_nocorrection':
242
243
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
244
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
245
246
    else:
247
        print' No such model name '
248
249
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
250
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])
251
252
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
253
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)
254
255
    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
256
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]
257
258
259
    saver = tf.train.Saver()
260
    data_gen_test = vox_generator_test(test_files)
261
    dice_whole, dice_core, dice_et = [], [], []
262
    with tf.Session() as sess:
263
        saver.restore(sess, SAVE_PATH)
264
        for i in range(len(test_files)):
265
            print 'predicting %s' % test_files[i]
266
            x, x_n, y = data_gen_test.next()
267
            pred = np.zeros([240, 240, 155, 5])
268
            for hi in range(batches_h):
269
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
270
                offset_ph = offset_h + OFFSET_PH
271
                for wi in range(batches_w):
272
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
273
                    offset_pw = offset_w + OFFSET_PW
274
                    for ci in range(batches_c):
275
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
276
                        offset_pc = offset_c + OFFSET_PC
277
                        data = x[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
278
                        data_norm = x_n[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
279
                        data_norm = np.expand_dims(data_norm, 0)
280
                        if not np.max(data) == 0 and np.min(data) == 0:
281
                            score = sess.run(fetches=t1_t1ce_score,
282
                                             feed_dict={flair_t2_node: data_norm[:, :, :, :, :2],
283
                                                        t1_t1ce_node: data_norm[:, :, :, :, 2:],
284
                                                        learning_phase(): 0}
285
                                             )
286
                            pred[offset_ph:offset_ph + PSIZE, offset_pw:offset_pw + PSIZE, offset_pc:offset_pc + PSIZE,
287
                            :] += np.squeeze(score)
288
289
            pred = np.argmax(pred, axis=-1)
290
            pred = pred.astype(int)
291
            print 'calculating dice...'
292
            whole_pred = (pred > 0).astype(int)
293
            whole_gt = (y > 0).astype(int)
294
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
295
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
296
            et_pred = (pred == 4).astype(int)
297
            et_gt = (y == 4).astype(int)
298
            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
299
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
300
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
301
            dice_whole.append(dice_whole_batch)
302
            dice_core.append(dice_core_batch)
303
            dice_et.append(dice_et_batch)
304
            print dice_whole_batch
305
            print dice_core_batch
306
            print dice_et_batch
307
308
        dice_whole = np.array(dice_whole)
309
        dice_core = np.array(dice_core)
310
        dice_et = np.array(dice_et)
311
312
        print 'mean dice whole:'
313
        print np.mean(dice_whole, axis=0)
314
        print 'mean dice core:'
315
        print np.mean(dice_core, axis=0)
316
        print 'mean dice enhance:'
317
        print np.mean(dice_et, axis=0)
318
319
        np.save(model_name + '_dice_whole', dice_whole)
320
        np.save(model_name + '_dice_core', dice_core)
321
        np.save(model_name + '_dice_enhance', dice_et)
322
        print 'pred saved'
323
324
325
if __name__ == '__main__':
326
    main()