Diff of /testvis.py [000000] .. [48d89d]

Switch to unified view

a b/testvis.py
1
"""
2
This code is to test NN model and visualize output
3
"""
4
import numpy as np
5
import sys
6
import time
7
import matplotlib.pyplot as plt
8
9
from keras.models import Model, load_model
10
from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D, BatchNormalization
11
from keras.optimizers import Adam, SGD
12
from keras.callbacks import ModelCheckpoint
13
from keras import backend as K
14
import tensorflow as tf
15
16
from data import load_train_data, load_test_data
17
from utils import *
18
19
K.set_image_data_format('channels_last')  # Tensorflow dimension ordering
20
21
data_path  = sys.argv[1] + "/"
22
model_path = data_path + "models/"
23
24
# dir for storing results that contains
25
rst_path = data_path + "test-records/"
26
if not os.path.exists(rst_path):
27
    os.makedirs(rst_path)
28
29
model_to_test = sys.argv[2]
30
cur_fold = sys.argv[3]
31
plane = sys.argv[4]
32
im_z = int(sys.argv[5])
33
im_y = int(sys.argv[6])
34
im_x = int(sys.argv[7])
35
high_range = float(sys.argv[8])
36
low_range = float(sys.argv[9])
37
margin = int(sys.argv[10])
38
vis = sys.argv[11]
39
40
# prediction of trained model
41
pred_path = os.path.join(rst_path, "pred-%s/"%cur_fold)
42
if not os.path.exists(pred_path):
43
    os.makedirs(pred_path)
44
45
"""
46
Dice Ceofficient and Cost functions for training
47
"""
48
smooth = 1.
49
50
def dice_coef(y_true, y_pred):
51
    y_true_f = K.flatten(y_true)
52
    y_pred_f = K.flatten(y_pred)
53
    intersection = K.sum(y_true_f * y_pred_f)
54
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
55
56
def dice_coef_loss(y_true, y_pred):
57
    return  -dice_coef(y_true, y_pred)
58
59
60
def test(model_to_test, current_fold, plane, rst_dir, vis):
61
    print "-"*50
62
    print "loading model ", model_to_test
63
    print "-"*50
64
65
    model = load_model(model_path + model_to_test + '.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef':dice_coef})
66
    volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines()
67
    total = len(volume_list)
68
69
    dsc = np.zeros((total, 2))
70
71
    # iterate all test cases
72
    for i in range(total):
73
        s = volume_list[i].split(' ')
74
        image = np.load(s[1])
75
        label = np.load(s[2])
76
77
        case_num = s[1].split("00")[1].split(".")[0]
78
        print "testing case: ", case_num
79
80
        image_ = np.transpose(image, (2, 0, 1))
81
        label_ = np.transpose(label, (2, 0, 1))
82
83
        # standardize test data
84
        image_[image_ < low_range] = low_range
85
        image_[image_ > high_range] = high_range
86
        image_ = (image_ - low_range) / float(high_range - low_range)
87
88
        # for creating final prediction visualization
89
        pred = np.zeros_like(image_)
90
91
        for sli in range(label_.shape[0]):
92
            try:
93
                # crop each slice according to smallest bounding box of each slice
94
                width = label_[sli].shape[0]
95
                height = label_[sli].shape[1]
96
97
                arr = np.nonzero(label_[sli])
98
99
                if len(arr[0]) == 0:
100
                    continue
101
102
                minA = min(arr[0])
103
                maxA = max(arr[0])
104
                minB = min(arr[1])
105
                maxB = max(arr[1])
106
107
                minAdiff = margin
108
                maxAdiff = margin
109
                minBdiff = margin
110
                maxBdiff = margin
111
112
                cropped = image_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \
113
                        max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)]
114
                cropped_mask = label_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \
115
                        max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)]
116
117
                image_padded_ = pad_2d(cropped, plane, 0, im_x, im_y, im_z)
118
                mask_padded_ = pad_2d(cropped_mask, plane, 0, im_x, im_y, im_z)
119
120
                image_padded_prep = preprocess_front(preprocess(image_padded_))
121
122
                out_ori = (model.predict(image_padded_prep) > 0.5).astype(np.uint8)
123
124
                out = out_ori[:,0:cropped.shape[0], 0:cropped.shape[1],:].reshape(cropped.shape)
125
                pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] = out
126
                pred_vis = pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)]
127
128
                if vis == "true":
129
                    fig = plt.figure()
130
                    ax = fig.add_subplot(1, 3, 1)
131
                    ax.set_title("input test image")
132
                    ax.imshow(cropped, cmap=plt.cm.gray)
133
134
                    ax = fig.add_subplot(1, 3, 2)
135
                    ax.set_title("prediction")
136
                    ax.imshow(pred_vis, cmap=plt.cm.gray)
137
138
                    ax = fig.add_subplot(1, 3, 3)
139
                    ax.set_title("ground truth")
140
                    ax.imshow(cropped_mask, cmap=plt.cm.gray)
141
142
                    # plt.suptitle("slice %s"%sli)
143
                    fig.canvas.set_window_title("slice %s"%sli)
144
                    plt.axis('off')
145
                    plt.show()
146
147
            except KeyboardInterrupt:
148
                print 'KeyboardInterrupt caught'
149
                raise ValueError("terminate because of keyboard interruption")
150
151
        # ------------ write out for visualization ---------------
152
        np.save(pred_path + case_num + ".npy", pred) # prediction made by the trained model
153
154
        # compute DSC
155
        cur_dsc, _, _, _ = DSC_computation(label_, pred)
156
        print cur_dsc
157
158
        dsc[i][0] = case_num
159
        dsc[i][1] = cur_dsc
160
161
    dsc_mean = np.mean(dsc[:,1])
162
    dsc_std = np.std(dsc[:,1])
163
164
    # record test dsc mean and standard deviation for each fold in the one file
165
    fd = open(rst_path + 'test_stats.csv','a+')
166
    fd.write("%s,%s,%s,%s\n"%(cur_fold, model_to_test, dsc_mean, dsc_std))
167
    fd.close()
168
169
    print "---------------------------------"
170
    print "mean: ", dsc_mean
171
    print "std: ", dsc_std
172
173
    # record test result case by case
174
    np.savetxt(rst_path + model_to_test + ".csv", dsc, fmt = "%i, %.5f", delimiter=",", header="case_num,DSC")
175
176
177
if __name__ == "__main__":
178
179
    start_time = time.time()
180
181
    test(model_to_test, cur_fold, plane, rst_path, vis)
182
183
    print "-----------test done, total time used: %s ------------"% (time.time() - start_time)