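"""test.py

Sliding-window evaluation of two-stream 3D networks on BraTS-style brain
volumes: restores a TensorFlow checkpoint, predicts each 240x240x155 case
patch by patch, and reports Dice scores for the whole tumor, tumor core,
and enhancing tumor regions.
"""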
import argparse
import os

import numpy as np
import tensorflow as tf

import tf_models

from keras.models import Model
from keras.layers import Input, Activation, Reshape, Lambda, concatenate
from keras.layers import Conv3D, AveragePooling3D, BatchNormalization
# learning_phase must come from the same Keras backend as the layers above.
from keras.backend import learning_phase

from nibabel import load as load_nii
from sklearn.preprocessing import scale

# Legacy hard-coded patch settings, superseded by the command-line options below:
# SAVE_PATH = 'unet3d_baseline.hdf5'
# OFFSET_W = 16
# OFFSET_H = 16
# OFFSET_C = 4
# HSIZE = 64
# WSIZE = 64
# CSIZE = 16
# batches_h, batches_w, batches_c = (224-HSIZE)/OFFSET_H+1, (224-WSIZE)/OFFSET_W+1, (152 - CSIZE)/OFFSET_C+1

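# With the defaults below, each 38x38x38 input patch contributes only its
# central 12x12x12 prediction region, and patch origins advance with stride 12
# so those central regions tile the whole volume.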
def parse_inputs():
    parser = argparse.ArgumentParser(description='Test different nets with 3D data.')
    parser.add_argument('-r', '--root-path', dest='root_path',
                        default='/mnt/disk1/dat/lchen63/brain/data/data2')
    parser.add_argument('-m', '--model-path', dest='model_path', default='NoneDense-0')
    parser.add_argument('-ow', '--offset-width', dest='offset_w', type=int, default=12)
    parser.add_argument('-oh', '--offset-height', dest='offset_h', type=int, default=12)
    parser.add_argument('-oc', '--offset-channel', dest='offset_c', type=int, default=12)
    parser.add_argument('-ws', '--width-size', dest='wsize', type=int, default=38)
    parser.add_argument('-hs', '--height-size', dest='hsize', type=int, default=38)
    parser.add_argument('-cs', '--channel-size', dest='csize', type=int, default=38)
    parser.add_argument('-ps', '--pred-size', dest='psize', type=int, default=12)
    parser.add_argument('-gpu', '--gpu', dest='gpu', type=str, default='0')
    parser.add_argument('-mn', '--model_name', dest='model_name', type=str, default='dense24')
    # Caveat: argparse's type=bool turns any non-empty string into True, so this
    # option is effectively always on unless an empty string is passed.
    parser.add_argument('-nc', '--correction', dest='correction', type=bool, default=True)
    return vars(parser.parse_args())

options = parse_inputs()
os.environ["CUDA_VISIBLE_DEVICES"] = options['gpu']

def segmentation_loss(y_true, y_pred, n_classes):
    y_true = tf.reshape(y_true, (-1, n_classes))
    y_pred = tf.reshape(y_pred, (-1, n_classes))
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                                  logits=y_pred))

def vox_preprocess(vox):
    # Standardize each channel to zero mean and unit variance over all voxels.
    vox_shape = vox.shape
    vox = np.reshape(vox, (-1, vox_shape[-1]))
    vox = scale(vox, axis=0)
    return np.reshape(vox, vox_shape)

def one_hot(y, num_classes):
    y_ = np.zeros([len(y), num_classes])
    y_[np.arange(len(y)), y] = 1
    return y_

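# Dice = 2 * |A ∩ B| / (|A| + |B|), computed per class on one-hot encodings,
# so dice_coef_np returns one score per class.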
def dice_coef_np(y_true, y_pred, num_classes):
    """
    :param y_true: sparse labels
    :param y_pred: sparse labels
    :param num_classes: number of classes
    :return: per-class Dice coefficients, an array of shape (num_classes,)
    """
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)
    y_true = one_hot(y_true.flatten(), num_classes)
    y_pred = one_hot(y_pred.flatten(), num_classes)
    intersection = np.sum(y_true * y_pred, axis=0)
    return (2. * intersection) / (np.sum(y_true, axis=0) + np.sum(y_pred, axis=0))

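# DenseNet-style block (Huang et al., "Densely Connected Convolutional
# Networks"): each iteration applies BN -> ReLU -> Conv3D and concatenates the
# new feature maps back onto its input, growing the channel count by
# growth_rate per layer.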
def DenseNetUnit3D(x, growth_rate, ksize, n, bn_decay=0.99):
    for i in range(n):
        concat = x
        x = BatchNormalization(center=True, scale=True, momentum=bn_decay)(x)
        x = Activation('relu')(x)
        x = Conv3D(filters=growth_rate, kernel_size=ksize, padding='same',
                   kernel_initializer='he_uniform', use_bias=False)(x)
        x = concatenate([concat, x])
    return x

def DenseNetTransit(x, rate=1, name=None):
    # Transition layer: optional 1x1x1 compression conv, then 2x downsampling.
    if rate != 1:
        out_features = int(x.get_shape().as_list()[-1] * rate)
        x = BatchNormalization(center=True, scale=True, name=name + '_bn')(x)
        x = Activation('relu', name=name + '_relu')(x)
        x = Conv3D(filters=out_features, kernel_size=1, strides=1, padding='same',
                   kernel_initializer='he_normal', use_bias=False, name=name + '_conv')(x)
    x = AveragePooling3D(pool_size=2, strides=2, padding='same')(x)
    return x

def dense_net(inputs):
    x = Conv3D(filters=24, kernel_size=3, strides=1, kernel_initializer='he_uniform',
               padding='same', use_bias=False)(inputs)
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
    x = DenseNetTransit(x)
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
    x = DenseNetTransit(x)
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, n=4)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x

def dense_model(patch_size, num_classes):
    # Splits the 4-channel input (FLAIR, T2, T1, T1ce) into three streams and
    # cascades them: FLAIR scores the whole tumor, FLAIR+T2 the core, and all
    # modalities the enhancing tumor. (Not used by main(), which builds its
    # graph from tf_models instead.)
    merged_inputs = Input(shape=patch_size + (4,), name='merged_inputs')
    flair = Reshape(patch_size + (1,))(
        Lambda(lambda l: l[:, :, :, :, 0], output_shape=patch_size + (1,))(merged_inputs))
    t2 = Reshape(patch_size + (1,))(
        Lambda(lambda l: l[:, :, :, :, 1], output_shape=patch_size + (1,))(merged_inputs))
    t1 = Lambda(lambda l: l[:, :, :, :, 2:], output_shape=patch_size + (2,))(merged_inputs)

    flair = dense_net(flair)
    t2 = dense_net(t2)
    t1 = dense_net(t1)

    t2 = concatenate([flair, t2])
    t1 = concatenate([t2, t1])

    tumor = Conv3D(2, kernel_size=1, strides=1, name='tumor')(flair)
    core = Conv3D(3, kernel_size=1, strides=1, name='core')(t2)
    enhancing = Conv3D(num_classes, kernel_size=1, strides=1, name='enhancing')(t1)
    net = Model(inputs=merged_inputs, outputs=[tumor, core, enhancing])

    return net

def norm(image):
    # Normalize using the statistics of nonzero (brain) voxels only.
    image = np.squeeze(image)
    image_nonzero = image[np.nonzero(image)]
    return (image - image_nonzero.mean()) / image_nonzero.std()

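# The generator below expects the standard BraTS layout rooted at --root-path:
# one directory per case containing <case>_flair.nii.gz, <case>_t2.nii.gz,
# <case>_t1.nii.gz, <case>_t1ce.nii.gz and <case>_seg.nii.gz (plus
# "_corrected" variants when --correction is set).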
def vox_generator_test(all_files):
    path = options['root_path']
    while True:
        for p in all_files:
            if options['correction']:
                flair = load_nii(os.path.join(path, p, p + '_flair_corrected.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, p, p + '_t2_corrected.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, p, p + '_t1_corrected.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, p, p + '_t1ce_corrected.nii.gz')).get_data()
            else:
                flair = load_nii(os.path.join(path, p, p + '_flair.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, p, p + '_t2.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, p, p + '_t1.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, p, p + '_t1ce.nii.gz')).get_data()

            # Stack the modalities channel-last: (H, W, C, 4).
            data = np.array([flair, t2, t1, t1ce])
            data = np.transpose(data, axes=[1, 2, 3, 0])

            data_norm = np.array([norm(flair), norm(t2), norm(t1), norm(t1ce)])
            data_norm = np.transpose(data_norm, axes=[1, 2, 3, 0])

            labels = load_nii(os.path.join(path, p, p + '_seg.nii.gz')).get_data()

            yield data, data_norm, labels

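# main() walks a 3D sliding window over each test volume: window origins are
# clamped at the border (min(OFFSET * i, dim - SIZE)), per-patch class scores
# are summed into the central prediction region, and a voxel-wise argmax over
# the accumulated scores produces the final label map.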
def main():
    # One case name per line in test.txt.
    test_files = []
    with open('test.txt') as f:
        for line in f:
            test_files.append(line.strip())

    num_labels = 5
    OFFSET_H = options['offset_h']
    OFFSET_W = options['offset_w']
    OFFSET_C = options['offset_c']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    PSIZE = options['psize']
    SAVE_PATH = options['model_path']
    model_name = options['model_name']

    # Margin between the input patch border and its central prediction region.
    OFFSET_PH = (HSIZE - PSIZE) // 2
    OFFSET_PW = (WSIZE - PSIZE) // 2
    OFFSET_PC = (CSIZE - PSIZE) // 2

    # Number of sliding-window positions per axis for a 240x240x155 volume.
    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))

    # Each branch returns two feature maps (the '15' and '27' scale outputs).
    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')
    elif model_name in ('dense24', 'dense24_nocorrection'):
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    # Fuse the two streams at each scale and classify with 1x1x1 convolutions.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum the two scales' scores over the central PSIZE region
    # (13:25 with the default patch sizes).
    t1_t1ce_score = t1_t1ce_15[:, OFFSET_PH:OFFSET_PH + PSIZE,
                               OFFSET_PW:OFFSET_PW + PSIZE,
                               OFFSET_PC:OFFSET_PC + PSIZE, :] + \
                    t1_t1ce_27[:, OFFSET_PH:OFFSET_PH + PSIZE,
                               OFFSET_PW:OFFSET_PW + PSIZE,
                               OFFSET_PC:OFFSET_PC + PSIZE, :]

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []
    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)
        for i in range(len(test_files)):
            print('predicting %s' % test_files[i])
            x, x_n, y = next(data_gen_test)
            pred = np.zeros([240, 240, 155, 5])
            for hi in range(batches_h):
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = offset_h + OFFSET_PH
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = offset_w + OFFSET_PW
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = offset_c + OFFSET_PC
                        data = x[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
                        data_norm = x_n[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
                        data_norm = np.expand_dims(data_norm, 0)
                        # Skip patches that are entirely background (all zeros).
                        if not (np.max(data) == 0 and np.min(data) == 0):
                            score = sess.run(fetches=t1_t1ce_score,
                                             feed_dict={flair_t2_node: data_norm[:, :, :, :, :2],
                                                        t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                                        learning_phase(): 0})
                            pred[offset_ph:offset_ph + PSIZE,
                                 offset_pw:offset_pw + PSIZE,
                                 offset_pc:offset_pc + PSIZE, :] += np.squeeze(score)

            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)
            print('calculating dice...')
            # BraTS label convention: 1 = necrotic/non-enhancing core, 2 = edema,
            # 4 = enhancing tumor; whole = {1, 2, 4}, core = {1, 4}, enhancing = {4}.
            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)
            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print(dice_whole_batch)
            print(dice_core_batch)
            print(dice_et_batch)

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)

        print('mean dice whole:')
        print(np.mean(dice_whole, axis=0))
        print('mean dice core:')
        print(np.mean(dice_core, axis=0))
        print('mean dice enhance:')
        print(np.mean(dice_et, axis=0))

        np.save(model_name + '_dice_whole', dice_whole)
        np.save(model_name + '_dice_core', dice_core)
        np.save(model_name + '_dice_enhance', dice_et)
        print('dice scores saved')

if __name__ == '__main__':
    main()
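# Example invocation (paths and checkpoint prefix are illustrative):
#   python test.py -r /path/to/brats_data -m checkpoints/dense24-final -mn dense24 -gpu 0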