|
a |
|
b/testvis.py |
|
|
1 |
""" |
|
|
2 |
This code is to test NN model and visualize output |
|
|
3 |
""" |
|
|
4 |
import numpy as np |
|
|
5 |
import sys |
|
|
6 |
import time |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
|
|
|
9 |
from keras.models import Model, load_model |
|
|
10 |
from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D, BatchNormalization |
|
|
11 |
from keras.optimizers import Adam, SGD |
|
|
12 |
from keras.callbacks import ModelCheckpoint |
|
|
13 |
from keras import backend as K |
|
|
14 |
import tensorflow as tf |
|
|
15 |
|
|
|
16 |
from data import load_train_data, load_test_data |
|
|
17 |
from utils import * |
|
|
18 |
|
|
|
19 |
K.set_image_data_format('channels_last') # Tensorflow dimension ordering |
|
|
20 |
|
|
|
21 |
data_path = sys.argv[1] + "/" |
|
|
22 |
model_path = data_path + "models/" |
|
|
23 |
|
|
|
24 |
# dir for storing results that contains |
|
|
25 |
rst_path = data_path + "test-records/" |
|
|
26 |
if not os.path.exists(rst_path): |
|
|
27 |
os.makedirs(rst_path) |
|
|
28 |
|
|
|
29 |
model_to_test = sys.argv[2] |
|
|
30 |
cur_fold = sys.argv[3] |
|
|
31 |
plane = sys.argv[4] |
|
|
32 |
im_z = int(sys.argv[5]) |
|
|
33 |
im_y = int(sys.argv[6]) |
|
|
34 |
im_x = int(sys.argv[7]) |
|
|
35 |
high_range = float(sys.argv[8]) |
|
|
36 |
low_range = float(sys.argv[9]) |
|
|
37 |
margin = int(sys.argv[10]) |
|
|
38 |
vis = sys.argv[11] |
|
|
39 |
|
|
|
40 |
# prediction of trained model |
|
|
41 |
pred_path = os.path.join(rst_path, "pred-%s/"%cur_fold) |
|
|
42 |
if not os.path.exists(pred_path): |
|
|
43 |
os.makedirs(pred_path) |
|
|
44 |
|
|
|
45 |
""" |
|
|
46 |
Dice Ceofficient and Cost functions for training |
|
|
47 |
""" |
|
|
48 |
smooth = 1. |
|
|
49 |
|
|
|
50 |
def dice_coef(y_true, y_pred): |
|
|
51 |
y_true_f = K.flatten(y_true) |
|
|
52 |
y_pred_f = K.flatten(y_pred) |
|
|
53 |
intersection = K.sum(y_true_f * y_pred_f) |
|
|
54 |
return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) |
|
|
55 |
|
|
|
56 |
def dice_coef_loss(y_true, y_pred): |
|
|
57 |
return -dice_coef(y_true, y_pred) |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def test(model_to_test, current_fold, plane, rst_dir, vis): |
|
|
61 |
print "-"*50 |
|
|
62 |
print "loading model ", model_to_test |
|
|
63 |
print "-"*50 |
|
|
64 |
|
|
|
65 |
model = load_model(model_path + model_to_test + '.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef':dice_coef}) |
|
|
66 |
volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines() |
|
|
67 |
total = len(volume_list) |
|
|
68 |
|
|
|
69 |
dsc = np.zeros((total, 2)) |
|
|
70 |
|
|
|
71 |
# iterate all test cases |
|
|
72 |
for i in range(total): |
|
|
73 |
s = volume_list[i].split(' ') |
|
|
74 |
image = np.load(s[1]) |
|
|
75 |
label = np.load(s[2]) |
|
|
76 |
|
|
|
77 |
case_num = s[1].split("00")[1].split(".")[0] |
|
|
78 |
print "testing case: ", case_num |
|
|
79 |
|
|
|
80 |
image_ = np.transpose(image, (2, 0, 1)) |
|
|
81 |
label_ = np.transpose(label, (2, 0, 1)) |
|
|
82 |
|
|
|
83 |
# standardize test data |
|
|
84 |
image_[image_ < low_range] = low_range |
|
|
85 |
image_[image_ > high_range] = high_range |
|
|
86 |
image_ = (image_ - low_range) / float(high_range - low_range) |
|
|
87 |
|
|
|
88 |
# for creating final prediction visualization |
|
|
89 |
pred = np.zeros_like(image_) |
|
|
90 |
|
|
|
91 |
for sli in range(label_.shape[0]): |
|
|
92 |
try: |
|
|
93 |
# crop each slice according to smallest bounding box of each slice |
|
|
94 |
width = label_[sli].shape[0] |
|
|
95 |
height = label_[sli].shape[1] |
|
|
96 |
|
|
|
97 |
arr = np.nonzero(label_[sli]) |
|
|
98 |
|
|
|
99 |
if len(arr[0]) == 0: |
|
|
100 |
continue |
|
|
101 |
|
|
|
102 |
minA = min(arr[0]) |
|
|
103 |
maxA = max(arr[0]) |
|
|
104 |
minB = min(arr[1]) |
|
|
105 |
maxB = max(arr[1]) |
|
|
106 |
|
|
|
107 |
minAdiff = margin |
|
|
108 |
maxAdiff = margin |
|
|
109 |
minBdiff = margin |
|
|
110 |
maxBdiff = margin |
|
|
111 |
|
|
|
112 |
cropped = image_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \ |
|
|
113 |
max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)] |
|
|
114 |
cropped_mask = label_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \ |
|
|
115 |
max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)] |
|
|
116 |
|
|
|
117 |
image_padded_ = pad_2d(cropped, plane, 0, im_x, im_y, im_z) |
|
|
118 |
mask_padded_ = pad_2d(cropped_mask, plane, 0, im_x, im_y, im_z) |
|
|
119 |
|
|
|
120 |
image_padded_prep = preprocess_front(preprocess(image_padded_)) |
|
|
121 |
|
|
|
122 |
out_ori = (model.predict(image_padded_prep) > 0.5).astype(np.uint8) |
|
|
123 |
|
|
|
124 |
out = out_ori[:,0:cropped.shape[0], 0:cropped.shape[1],:].reshape(cropped.shape) |
|
|
125 |
pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] = out |
|
|
126 |
pred_vis = pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] |
|
|
127 |
|
|
|
128 |
if vis == "true": |
|
|
129 |
fig = plt.figure() |
|
|
130 |
ax = fig.add_subplot(1, 3, 1) |
|
|
131 |
ax.set_title("input test image") |
|
|
132 |
ax.imshow(cropped, cmap=plt.cm.gray) |
|
|
133 |
|
|
|
134 |
ax = fig.add_subplot(1, 3, 2) |
|
|
135 |
ax.set_title("prediction") |
|
|
136 |
ax.imshow(pred_vis, cmap=plt.cm.gray) |
|
|
137 |
|
|
|
138 |
ax = fig.add_subplot(1, 3, 3) |
|
|
139 |
ax.set_title("ground truth") |
|
|
140 |
ax.imshow(cropped_mask, cmap=plt.cm.gray) |
|
|
141 |
|
|
|
142 |
# plt.suptitle("slice %s"%sli) |
|
|
143 |
fig.canvas.set_window_title("slice %s"%sli) |
|
|
144 |
plt.axis('off') |
|
|
145 |
plt.show() |
|
|
146 |
|
|
|
147 |
except KeyboardInterrupt: |
|
|
148 |
print 'KeyboardInterrupt caught' |
|
|
149 |
raise ValueError("terminate because of keyboard interruption") |
|
|
150 |
|
|
|
151 |
# ------------ write out for visualization --------------- |
|
|
152 |
np.save(pred_path + case_num + ".npy", pred) # prediction made by the trained model |
|
|
153 |
|
|
|
154 |
# compute DSC |
|
|
155 |
cur_dsc, _, _, _ = DSC_computation(label_, pred) |
|
|
156 |
print cur_dsc |
|
|
157 |
|
|
|
158 |
dsc[i][0] = case_num |
|
|
159 |
dsc[i][1] = cur_dsc |
|
|
160 |
|
|
|
161 |
dsc_mean = np.mean(dsc[:,1]) |
|
|
162 |
dsc_std = np.std(dsc[:,1]) |
|
|
163 |
|
|
|
164 |
# record test dsc mean and standard deviation for each fold in the one file |
|
|
165 |
fd = open(rst_path + 'test_stats.csv','a+') |
|
|
166 |
fd.write("%s,%s,%s,%s\n"%(cur_fold, model_to_test, dsc_mean, dsc_std)) |
|
|
167 |
fd.close() |
|
|
168 |
|
|
|
169 |
print "---------------------------------" |
|
|
170 |
print "mean: ", dsc_mean |
|
|
171 |
print "std: ", dsc_std |
|
|
172 |
|
|
|
173 |
# record test result case by case |
|
|
174 |
np.savetxt(rst_path + model_to_test + ".csv", dsc, fmt = "%i, %.5f", delimiter=",", header="case_num,DSC") |
|
|
175 |
|
|
|
176 |
|
|
|
177 |
if __name__ == "__main__": |
|
|
178 |
|
|
|
179 |
start_time = time.time() |
|
|
180 |
|
|
|
181 |
test(model_to_test, cur_fold, plane, rst_path, vis) |
|
|
182 |
|
|
|
183 |
print "-----------test done, total time used: %s ------------"% (time.time() - start_time) |