a b/BraTs18Challege/Vnet/model_vnet3d.py
1
'''
2
3
'''
4
from Vnet.layer import (conv_bn_relu_drop, down_sampling, deconv_relu, crop_and_concat, resnet_Add, conv_sigmod,
5
                        save_images)
6
import tensorflow as tf
7
import numpy as np
8
import os
9
10
11
def _stacked_convs(x, in_channels, out_channels, scopes, phase, drop, **size_kwargs):
    """Chain one 3x3x3 ``conv_bn_relu_drop`` call per entry of ``scopes``.

    The first convolution maps ``in_channels`` to ``out_channels``; every
    following one keeps ``out_channels``.

    :param x: input tensor of the chain
    :param in_channels: channel count of ``x``
    :param out_channels: channel count produced by every convolution
    :param scopes: iterable of scope names, one per convolution, in order
    :param phase: forwarded unchanged to ``conv_bn_relu_drop``
    :param drop: forwarded unchanged to ``conv_bn_relu_drop``
    :param size_kwargs: optional ``image_z`` / ``height`` / ``width`` keywords
        forwarded unchanged (only the decoder stages pass them)
    :return: output tensor of the last convolution
    """
    channels = in_channels
    for scope in scopes:
        x = conv_bn_relu_drop(x=x, kernal=(3, 3, 3, channels, out_channels), phase=phase, drop=drop,
                              scope=scope, **size_kwargs)
        channels = out_channels
    return x


def _create_conv_net(X, image_z, image_width, image_height, image_channel, phase, drop, n_class=1):
    """Build the VNet graph: four down-sampling stages, four up-sampling stages.

    Every stage adds its input to the output of its convolution chain
    (``resnet_Add``); every decoder stage first concatenates the matching
    encoder output with the up-sampled tensor (``crop_and_concat``).

    :param X: input tensor; it is reshaped to
        ``[-1, image_z, image_height, image_width, image_channel]``
    :param image_z: depth of the input volume
    :param image_width: width of the input volume
    :param image_height: height of the input volume
    :param image_channel: number of input channels
    :param phase: forwarded to the layer helpers
        (presumably the batch-norm training flag -- confirm in Vnet.layer)
    :param drop: forwarded to the layer helpers
        (presumably the dropout keep probability -- confirm in Vnet.layer)
    :param n_class: number of output channels of the final 1x1x1 convolution
    :return: tensor returned by ``conv_sigmod`` for the last decoder stage
    """
    # The caller's placeholder is laid out (depth, height, width, channel) and the
    # decoder below reads shapes as (Z, H, W), so the reshape must use the same
    # order.  The previous (z, width, height) order scrambled non-square volumes;
    # for square volumes the two are identical.
    inputX = tf.reshape(X, [-1, image_z, image_height, image_width, image_channel])

    # Vnet model
    # first stage: two single convolutions, residual taken from the first one
    layer0 = _stacked_convs(inputX, image_channel, 16, ('layer0',), phase, drop)
    layer1 = _stacked_convs(layer0, 16, 16, ('layer1',), phase, drop)
    net = resnet_Add(x1=layer0, x2=layer1)

    # encoder: (down-sampling scope, conv scope prefix, in channels, out channels, conv count)
    encoder_stages = (
        ('down1', 'layer2', 16, 32, 2),
        ('down2', 'layer3', 32, 64, 3),
        ('down3', 'layer4', 64, 128, 3),
        ('down4', 'layer5', 128, 256, 3),
    )
    skips = [net]  # stage outputs kept for the decoder's concatenations
    for down_scope, prefix, in_ch, out_ch, n_convs in encoder_stages:
        down = down_sampling(x=net, kernal=(3, 3, 3, in_ch, out_ch), phase=phase, drop=drop, scope=down_scope)
        scopes = ['%s_%d' % (prefix, k) for k in range(1, n_convs + 1)]
        convs = _stacked_convs(down, out_ch, out_ch, scopes, phase, drop)
        net = resnet_Add(x1=down, x2=convs)
        skips.append(net)
    # the deepest stage output (layer5) is the decoder input, not a skip connection
    skips.pop()

    # decoder: (deconvolution scope, conv scope prefix, channels of the matching encoder stage)
    decoder_stages = (
        ('deconv1', 'layer6', 128),
        ('deconv2', 'layer7', 64),
        ('deconv3', 'layer8', 32),
        ('deconv4', 'layer9', 16),
    )
    for deconv_scope, prefix, ch in decoder_stages:
        skip = skips.pop()  # layer4, layer3, layer2, layer1 in turn
        deconv = deconv_relu(x=net, kernal=(3, 3, 3, ch, 2 * ch), scope=deconv_scope)
        merged = crop_and_concat(skip, deconv)
        _, Z, H, W, _ = skip.get_shape().as_list()
        scopes = ['%s_%d' % (prefix, k) for k in range(1, 4)]
        # the concatenation doubles the channel count, the first conv halves it again
        convs = _stacked_convs(merged, 2 * ch, ch, scopes, phase, drop, image_z=Z, height=H, width=W)
        net = resnet_Add(x1=deconv, x2=convs)

    # output layer
    output_map = conv_sigmod(x=net, kernal=(1, 1, 1, 16, n_class), scope='output')
    return output_map
110
111
112
# Serve data by batches
113
def _next_batch(train_images, train_labels, batch_size, index_in_epoch):
114
    start = index_in_epoch
115
    index_in_epoch += batch_size
116
117
    num_examples = train_images.shape[0]
118
    # when all trainig data have been already used, it is reorder randomly
119
    if index_in_epoch > num_examples:
120
        # shuffle the data
121
        perm = np.arange(num_examples)
122
        np.random.shuffle(perm)
123
        train_images = train_images[perm]
124
        train_labels = train_labels[perm]
125
        # start next epoch
126
        start = 0
127
        index_in_epoch = batch_size
128
        assert batch_size <= num_examples
129
    end = index_in_epoch
130
    return train_images[start:end], train_labels[start:end], index_in_epoch
131
132
133
class Vnet3dModule(object):
    """
        A VNet3d implementation (TensorFlow 1.x graph API)

        :param image_height: number of height in the input image
        :param image_width: number of width in the input image
        :param image_depth: number of depth in the input image
        :param channels: number of channels in the input image
        :param numclass: number of channels of the ground truth and of the prediction
        :param costname: sequence whose first element names the cost function;
            only "dice coefficient" is implemented
        :param inference: when True, open a session and restore the weights stored
            at model_path so that prediction() can be called
        :param model_path: checkpoint path to restore; required when inference is True
    """

    def __init__(self, image_height, image_width, image_depth, channels=1, numclass=1, costname=("dice coefficient",),
                 inference=False, model_path=None):
        if inference and model_path is None:
            raise ValueError("model_path is required when inference=True")

        self.image_width = image_width
        self.image_height = image_height
        self.image_depth = image_depth
        self.channels = channels
        self.numclass = numclass

        volume_shape = [None, self.image_depth, self.image_height, self.image_width]
        self.X = tf.placeholder("float", shape=volume_shape + [self.channels])
        self.Y_gt = tf.placeholder("float", shape=volume_shape + [self.numclass])
        self.lr = tf.placeholder('float')
        self.phase = tf.placeholder(tf.bool)
        self.drop = tf.placeholder('float')

        self.Y_pred = _create_conv_net(self.X, self.image_depth, self.image_width, self.image_height, self.channels,
                                       self.phase, self.drop, self.numclass)
        self.cost = self.__get_cost(self.Y_pred, self.Y_gt, costname[0])
        # the cost is the negated mean dice score, so its negation is reported as accuracy
        self.accuracy = -self.cost

        if inference:
            init = tf.global_variables_initializer()
            saver = tf.train.Saver()
            self.sess = tf.InteractiveSession()
            self.sess.run(init)
            saver.restore(self.sess, model_path)

    def __get_cost(self, Y_pred, Y_gt, cost_name):
        """Build the loss tensor for ``cost_name``.

        :param Y_pred: predicted volume tensor
        :param Y_gt: ground-truth tensor; its static shape fixes the flattening size
        :param cost_name: name of the cost function, only "dice coefficient" is known
        :return: scalar tensor, the negated mean (over the batch) dice score
        :raises ValueError: for any other ``cost_name`` (this used to surface as
            an UnboundLocalError on the return statement)
        """
        if cost_name != "dice coefficient":
            raise ValueError("unsupported cost function %r, only 'dice coefficient' is implemented" % (cost_name,))

        Z, H, W, C = Y_gt.get_shape().as_list()[1:]
        smooth = 1e-5  # keeps the ratio defined when both volumes are empty
        flat_size = H * W * C * Z
        pred_flat = tf.reshape(Y_pred, [-1, flat_size])
        true_flat = tf.reshape(Y_gt, [-1, flat_size])
        intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + smooth
        denominator = tf.reduce_sum(pred_flat, axis=1) + tf.reduce_sum(true_flat, axis=1) + smooth
        return -tf.reduce_mean(intersection / denominator)

    def _save_channel_images(self, volume, prefix, step, showwindow, logs_path):
        """Write one PNG per class channel of ``volume`` through ``save_images``.

        Files are named ``<prefix><channel number>_<step>_epoch.png`` (channel
        numbers start at 1) and are placed by plain concatenation with ``logs_path``.
        """
        volume = np.reshape(volume, (self.image_depth, self.image_height, self.image_width, self.numclass))
        for channel in range(self.numclass):
            plane = volume[:, :, :, channel].astype(np.float64)
            save_images(plane, showwindow, path=logs_path + '%s%d_%d_epoch.png' % (prefix, channel + 1, step))

    def train(self, train_images, train_lanbels, model_path, logs_path, learning_rate,
              dropout_conv=0.8, train_epochs=5, batch_size=1, showwindow=[8, 8]):
        """Train the network and periodically dump losses, images and checkpoints.

        :param train_images: array of paths to ``.npy`` image volumes
        :param train_lanbels: array of paths to ``.npy`` label volumes, aligned with train_images
        :param model_path: checkpoint file name, stored under ``logs_path + "model\\\\"``
        :param logs_path: output directory prefix; it is concatenated as-is with
            file names, so it is expected to end with a path separator
        :param learning_rate: value fed to the Adam optimizer
        :param dropout_conv: value fed to the ``drop`` placeholder while training
        :param train_epochs: number of passes; the loop runs
            ``train_images.shape[0] * train_epochs`` steps
        :param batch_size: number of volumes per optimisation step
        :param showwindow: forwarded unchanged to ``save_images`` (never modified here)
        """
        # number of volumes kept in memory at once; a fresh set is loaded every num_sample steps
        num_sample = 100
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        # NOTE(review): the hard-coded "\\" separator only forms a sub-directory on
        # Windows; kept unchanged so existing checkpoint locations stay valid.
        if not os.path.exists(logs_path + "model\\"):
            os.makedirs(logs_path + "model\\")
        model_path = logs_path + "model\\" + model_path
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        # tf.all_variables() is the deprecated spelling of tf.global_variables()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        try:
            sess.run(init)

            # NOTE(review): a checkpoint prefix is usually not an existing file itself,
            # so this restore may never trigger -- confirm before relying on resuming.
            if os.path.exists(model_path):
                saver.restore(sess, model_path)

            # load data and show result param
            display_step = 1
            num_sample_index_in_epoch = 0
            index_in_epoch = 0

            total_steps = train_images.shape[0] * train_epochs
            volume_shape = (self.image_depth, self.image_height, self.image_width)

            # np.empty yields float64 arrays, so no further dtype conversion is needed
            subbatch_xs = np.empty((num_sample,) + volume_shape + (self.channels,))
            subbatch_ys = np.empty((num_sample,) + volume_shape + (self.numclass,))

            for i in range(total_steps):
                # Extracting num_sample images and labels from given data
                if i % num_sample == 0:
                    batch_xs_path, batch_ys_path, num_sample_index_in_epoch = _next_batch(
                        train_images, train_lanbels, num_sample, num_sample_index_in_epoch)
                    for num, (image_path, label_path) in enumerate(zip(batch_xs_path, batch_ys_path)):
                        image = np.load(image_path)
                        label = np.load(label_path)
                        subbatch_xs[num] = np.reshape(image, volume_shape + (self.channels,))
                        # one binary mask per label value: 1 -> channel 0, 2 -> channel 1, 4 -> channel 2
                        label_ys = np.zeros(volume_shape + (self.numclass,))
                        for channel, value in enumerate((1., 2., 4.)):
                            label_ys[:, :, :, channel] = (label == value)
                        subbatch_ys[num] = label_ys

                # get new batch
                batch_xs, batch_ys, index_in_epoch = _next_batch(subbatch_xs, subbatch_ys, batch_size, index_in_epoch)
                train_feed = {self.X: batch_xs,
                              self.Y_gt: batch_ys,
                              self.lr: learning_rate,
                              self.phase: 1,
                              self.drop: dropout_conv}

                # check progress on every 1st,2nd,...,10th,20th,...,100th... step
                if i % display_step == 0 or (i + 1) == total_steps:
                    train_loss, train_accuracy = sess.run([self.cost, self.accuracy], feed_dict=train_feed)
                    print('epochs %d training_loss ,training_accuracy => %.5f,%.5f ' % (i, train_loss, train_accuracy))

                    pred = sess.run(self.Y_pred, feed_dict={self.X: batch_xs,
                                                            self.Y_gt: batch_ys,
                                                            self.phase: 1,
                                                            self.drop: 1})
                    # dump the first sample of the batch, ground truth first, then prediction
                    self._save_channel_images(batch_ys[0], 'gt', i, showwindow, logs_path)
                    self._save_channel_images(pred[0], 'predict', i, showwindow, logs_path)

                    save_path = saver.save(sess, model_path, global_step=i)
                    print("Model saved in file:", save_path)
                    if i % (display_step * 10) == 0 and i:
                        display_step *= 10

                # train on batch
                _, summary = sess.run([train_op, merged_summary_op], feed_dict=train_feed)
                summary_writer.add_summary(summary, i)

            save_path = saver.save(sess, model_path)
            print("Model saved in file:", save_path)
        finally:
            # release the writer and the session even when training fails midway
            summary_writer.close()
            sess.close()

    def prediction(self, test_images):
        """Run the restored network on one volume.

        Requires the instance to have been built with ``inference=True``
        (``self.sess`` is only created in that case).

        :param test_images: array whose first three axes are the volume axes; it is
            reshaped to ``(shape[0], shape[1], shape[2], channels)``
        :return: uint8 array of shape ``(shape[0], shape[1], shape[2], numclass)``
            holding the network output scaled by 255 and clipped to [0, 255]
        """
        spatial = (test_images.shape[0], test_images.shape[1], test_images.shape[2])
        test_images = np.reshape(test_images, spatial + (self.channels,))
        # np.float was removed from NumPy; np.float64 is the type it used to alias
        test_images = test_images.astype(np.float64)
        # placeholder ground truth: it must carry numclass channels to match Y_gt
        y_dummy = np.zeros(spatial + (self.numclass,))
        # NOTE(review): phase is fed as 1 here exactly as during training -- confirm
        # this is the intended batch-norm mode for inference.
        pred = self.sess.run(self.Y_pred, feed_dict={self.X: [test_images], self.Y_gt: [y_dummy], self.phase: 1,
                                                     self.drop: 1})
        result = pred.astype(np.float64) * 255.
        result = np.clip(result, 0, 255).astype('uint8')
        result = np.reshape(result, spatial + (self.numclass,))
        return result