|
a |
|
b/BraTs18Challege/Vnet/model_vnet3d.py |
|
|
1 |
''' |
|
|
2 |
|
|
|
3 |
''' |
|
|
4 |
from Vnet.layer import (conv_bn_relu_drop, down_sampling, deconv_relu, crop_and_concat, resnet_Add, conv_sigmod, |
|
|
5 |
save_images) |
|
|
6 |
import tensorflow as tf |
|
|
7 |
import numpy as np |
|
|
8 |
import os |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
def _create_conv_net(X, image_z, image_width, image_height, image_channel, phase, drop, n_class=1):
    """Build the 3D V-Net graph and return its output map.

    Encoder: five levels with 16, 32, 64, 128 and 256 feature channels. Each
    level is a stack of 3x3x3 ``conv_bn_relu_drop`` layers closed by a
    ``resnet_Add`` residual connection. Consecutive levels are joined by
    ``down_sampling`` layers whose kernels double the channel count
    (presumably also halving the spatial size -- see layer.py).

    Decoder: four ``deconv_relu`` steps that halve the channel count. Each
    result is merged with the same-level encoder output by ``crop_and_concat``
    and then passed through three ``conv_bn_relu_drop`` layers plus a
    ``resnet_Add`` with the deconvolution output. A 1x1x1 ``conv_sigmod``
    produces the final map.

    The ``scope`` strings ('layer0', 'down1', 'layer6_1', ...) are presumably
    variable-scope names, i.e. checkpoint keys, so they must not be renamed.

    :param X: input tensor, reshaped to
        ``[-1, image_z, image_height, image_width, image_channel]``. This is
        the same (depth, height, width, channel) layout as ``Vnet3dModule.X``.
    :param image_z: depth of the input volume.
    :param image_width: width of the input volume.
    :param image_height: height of the input volume.
    :param image_channel: number of input channels.
    :param phase: boolean tensor forwarded to every conv_bn_relu_drop and
        down_sampling layer (presumably the batch-norm training flag).
    :param drop: tensor forwarded as the dropout argument of those layers.
    :param n_class: number of output channels.
    :return: the ``conv_sigmod`` output tensor with ``n_class`` channels.
    """
    # Reshape in (depth, height, width) order to match the data fed through
    # Vnet3dModule.X. The former (z, width, height) order reinterprets (it does
    # not transpose) the buffer, which scrambles every non-square volume.
    inputX = tf.reshape(X, [-1, image_z, image_height, image_width, image_channel])

    # ---- Encoder level 1: 16 channels
    layer0 = conv_bn_relu_drop(x=inputX, kernal=(3, 3, 3, image_channel, 16), phase=phase, drop=drop,
                               scope='layer0')
    layer1 = conv_bn_relu_drop(x=layer0, kernal=(3, 3, 3, 16, 16), phase=phase, drop=drop,
                               scope='layer1')
    # The input has image_channel channels, so the residual starts at layer0.
    layer1 = resnet_Add(x1=layer0, x2=layer1)
    down1 = down_sampling(x=layer1, kernal=(3, 3, 3, 16, 32), phase=phase, drop=drop, scope='down1')

    # ---- Encoder level 2: 32 channels
    layer2 = conv_bn_relu_drop(x=down1, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_1')
    layer2 = conv_bn_relu_drop(x=layer2, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_2')
    layer2 = resnet_Add(x1=down1, x2=layer2)
    down2 = down_sampling(x=layer2, kernal=(3, 3, 3, 32, 64), phase=phase, drop=drop, scope='down2')

    # ---- Encoder level 3: 64 channels
    layer3 = conv_bn_relu_drop(x=down2, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_1')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_2')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_3')
    layer3 = resnet_Add(x1=down2, x2=layer3)
    down3 = down_sampling(x=layer3, kernal=(3, 3, 3, 64, 128), phase=phase, drop=drop, scope='down3')

    # ---- Encoder level 4: 128 channels
    layer4 = conv_bn_relu_drop(x=down3, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_1')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_2')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_3')
    layer4 = resnet_Add(x1=down3, x2=layer4)
    down4 = down_sampling(x=layer4, kernal=(3, 3, 3, 128, 256), phase=phase, drop=drop, scope='down4')

    # ---- Bottom level 5: 256 channels
    layer5 = conv_bn_relu_drop(x=down4, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_1')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_2')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_3')
    layer5 = resnet_Add(x1=down4, x2=layer5)

    # ---- Decoder level 4: deconvolve 256 -> 128, merge with layer4 (128 + 128 = 256 channels)
    deconv1 = deconv_relu(x=layer5, kernal=(3, 3, 3, 128, 256), scope='deconv1')
    layer6 = crop_and_concat(layer4, deconv1)
    # Static spatial size of the skip connection, passed explicitly to the conv
    # layers (presumably so they can restore shape information after the concat).
    _, Z, H, W, _ = layer4.get_shape().as_list()
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 256, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_1')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_2')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_3')
    layer6 = resnet_Add(x1=deconv1, x2=layer6)

    # ---- Decoder level 3: deconvolve 128 -> 64, merge with layer3
    deconv2 = deconv_relu(x=layer6, kernal=(3, 3, 3, 64, 128), scope='deconv2')
    layer7 = crop_and_concat(layer3, deconv2)
    _, Z, H, W, _ = layer3.get_shape().as_list()
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 128, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_1')
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 64, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_2')
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 64, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_3')
    layer7 = resnet_Add(x1=deconv2, x2=layer7)

    # ---- Decoder level 2: deconvolve 64 -> 32, merge with layer2
    deconv3 = deconv_relu(x=layer7, kernal=(3, 3, 3, 32, 64), scope='deconv3')
    layer8 = crop_and_concat(layer2, deconv3)
    _, Z, H, W, _ = layer2.get_shape().as_list()
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 64, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer8_1')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer8_2')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer8_3')
    layer8 = resnet_Add(x1=deconv3, x2=layer8)

    # ---- Decoder level 1: deconvolve 32 -> 16, merge with layer1
    deconv4 = deconv_relu(x=layer8, kernal=(3, 3, 3, 16, 32), scope='deconv4')
    layer9 = crop_and_concat(layer1, deconv4)
    _, Z, H, W, _ = layer1.get_shape().as_list()
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 16), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer9_1')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 16, 16), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer9_2')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 16, 16), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer9_3')
    layer9 = resnet_Add(x1=deconv4, x2=layer9)

    # ---- Output: 1x1x1 convolution to n_class channels
    output_map = conv_sigmod(x=layer9, kernal=(1, 1, 1, 16, n_class), scope='output')
    return output_map
|
|
110 |
|
|
|
111 |
|
|
|
112 |
# Serve data by batches |
|
|
113 |
def _next_batch(train_images, train_labels, batch_size, index_in_epoch):
    """Return the next ``batch_size`` items of ``train_images``/``train_labels``.

    :param train_images: array supporting slicing and integer-array indexing
        (e.g. a numpy array); items are taken along axis 0.
    :param train_labels: array aligned with ``train_images`` along axis 0.
    :param batch_size: number of items to return.
    :param index_in_epoch: position of the next unread item.
    :return: ``(images, labels, new_index_in_epoch)``.
    :raises ValueError: if ``batch_size`` exceeds the number of items.

    If the requested batch would run past the end of the data, three things
    happen: the trailing items are skipped, a fresh random permutation is
    drawn, and the batch is taken from the start of the permuted data.

    NOTE(review): the permuted arrays are local to this call and are not
    returned. Only that first batch of each epoch is therefore shuffled;
    later calls index the caller's (unshuffled) arrays again.
    """
    num_examples = train_images.shape[0]
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then silently return a short batch.
    if batch_size > num_examples:
        raise ValueError("batch_size (%s) exceeds the number of examples (%s)" % (batch_size, num_examples))

    start = index_in_epoch
    end = index_in_epoch + batch_size
    if end > num_examples:
        # All data used: restart from a random reordering of the data.
        perm = np.random.permutation(num_examples)
        train_images = train_images[perm]
        train_labels = train_labels[perm]
        start, end = 0, batch_size
    return train_images[start:end], train_labels[start:end], end
|
|
131 |
|
|
|
132 |
|
|
|
133 |
class Vnet3dModule(object):
    """
    A VNet3d implementation (TensorFlow 1.x graph/session API).

    :param image_height: number of height in the input image
    :param image_width: number of width in the input image
    :param image_depth: number of depth in the input image
    :param channels: number of channels in the input image
    :param numclass: number of output channels; train() requires 3
    :param costname: name of the cost function.Default is "dice coefficient"
        (the only supported value; only costname[0] is used)
    :param inference: when True, open an InteractiveSession and restore the
        weights from model_path so that prediction() can be called
    :param model_path: checkpoint path restored when inference is True
    """

    # Label values that train() turns into the three binary target channels.
    # NOTE(review): 1/2/4 presumably follow the BraTS label convention (the
    # file lives under BraTs18Challege); confirm against the data.
    _LABEL_VALUES = (1., 2., 4.)

    def __init__(self, image_height, image_width, image_depth, channels=1, numclass=1, costname=("dice coefficient",),
                 inference=False, model_path=None):
        self.image_width = image_width
        self.image_height = image_height
        self.image_depth = image_depth
        self.channels = channels
        self.numclass = numclass

        # Inputs and targets are laid out (batch, depth, height, width, channel).
        self.X = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                self.channels])
        self.Y_gt = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                   self.numclass])
        self.lr = tf.placeholder('float')
        self.phase = tf.placeholder(tf.bool)
        self.drop = tf.placeholder('float')

        self.Y_pred = _create_conv_net(self.X, image_z=self.image_depth, image_width=self.image_width,
                                       image_height=self.image_height, image_channel=self.channels,
                                       phase=self.phase, drop=self.drop, n_class=self.numclass)
        self.cost = self.__get_cost(self.Y_pred, self.Y_gt, costname[0])
        # The cost is the negative mean Dice score, so "accuracy" is the Dice score.
        self.accuracy = -self.cost

        if inference:
            init = tf.global_variables_initializer()
            saver = tf.train.Saver()
            self.sess = tf.InteractiveSession()
            self.sess.run(init)
            saver.restore(self.sess, model_path)

    def __get_cost(self, Y_pred, Y_gt, cost_name):
        """Return the training loss tensor.

        Only "dice coefficient" is supported. The loss is the negative soft
        Dice score. It is computed per sample over all voxels and channels
        flattened together (with a 1e-5 smoothing term), then averaged over
        the batch.

        :raises ValueError: for any other cost_name. The old code returned
            None here, which later failed with an opaque TypeError.
        """
        if cost_name != "dice coefficient":
            raise ValueError("unsupported cost function: %r" % (cost_name,))
        Z, H, W, C = Y_gt.get_shape().as_list()[1:]
        smooth = 1e-5
        pred_flat = tf.reshape(Y_pred, [-1, H * W * C * Z])
        true_flat = tf.reshape(Y_gt, [-1, H * W * C * Z])
        numerator = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + smooth
        denominator = tf.reduce_sum(pred_flat, axis=1) + tf.reduce_sum(true_flat, axis=1) + smooth
        return -tf.reduce_mean(numerator / denominator)

    @classmethod
    def _encode_brats_label(cls, label, out):
        """Write one binary mask per entry of _LABEL_VALUES into out[..., k].

        Mask k is 1.0 where ``label == _LABEL_VALUES[k]`` and 0.0 elsewhere.
        ``label`` must broadcast to ``out.shape[:-1]``.
        """
        for channel, value in enumerate(cls._LABEL_VALUES):
            out[..., channel] = (label == value)

    def _load_subbatch(self, image_paths, label_paths, subbatch_xs, subbatch_ys):
        """Load paired .npy volumes into the preallocated buffers, in place.

        Row k of subbatch_xs receives image_paths[k], reshaped to
        (depth, height, width, channels). Row k of subbatch_ys receives the
        per-label binary masks of label_paths[k].
        """
        image_shape = (self.image_depth, self.image_height, self.image_width, self.channels)
        for num, (image_path, label_path) in enumerate(zip(image_paths, label_paths)):
            image = np.load(image_path)
            label = np.load(label_path)
            subbatch_xs[num] = np.reshape(image, image_shape)
            self._encode_brats_label(label, subbatch_ys[num])

    def _save_channel_images(self, volume, prefix, step, showwindow, logs_path):
        """Save every channel of one volume as a tiled image via save_images.

        ``volume`` is reshaped to (depth, height, width, numclass). Channel k
        (0-based) is written to ``<logs_path>/<prefix><k+1>_<step>_epoch.png``.
        """
        volume = np.reshape(volume, (self.image_depth, self.image_height, self.image_width, self.numclass))
        for channel in range(self.numclass):
            image = volume[:, :, :, channel].astype(np.float64)
            file_name = '%s%d_%d_epoch.png' % (prefix, channel + 1, step)
            save_images(image, showwindow, path=os.path.join(logs_path, file_name))

    def train(self, train_images, train_lanbels, model_path, logs_path, learning_rate,
              dropout_conv=0.8, train_epochs=5, batch_size=1, showwindow=None):
        """Train the network with Adam on paired .npy image/label volumes.

        :param train_images: array of paths to image volumes (read with
            np.load). It must support integer-array indexing (e.g. a numpy array).
        :param train_lanbels: array of paths to label volumes, aligned with
            train_images.
        :param model_path: checkpoint file name, stored under <logs_path>/model.
            If a checkpoint with that prefix exists, training resumes from it.
        :param logs_path: directory for TensorBoard summaries, checkpoints and
            preview images.
        :param learning_rate: Adam learning rate.
        :param dropout_conv: value fed to the ``drop`` placeholder while
            training (presumably a keep probability; 1 is fed for previews).
        :param train_epochs: multiplied by the number of training volumes to
            give the number of training steps. This equals epochs only when
            batch_size == 1.
        :param batch_size: volumes per training step.
        :param showwindow: tiling grid passed to save_images (default [8, 8]).
        :raises ValueError: if numclass differs from the number of label
            values train() encodes (3).

        Volumes are loaded from disk in blocks of up to 100 every 100 steps.
        NOTE(review): with batch_size > 1 a block is cycled through more than
        once before it is refreshed.
        """
        if showwindow is None:
            showwindow = [8, 8]
        if self.numclass != len(self._LABEL_VALUES):
            # Otherwise extra target channels would hold uninitialised memory.
            raise ValueError("train() builds %d target channels from label values %s, but numclass is %d"
                             % (len(self._LABEL_VALUES), self._LABEL_VALUES, self.numclass))
        # Never request more volumes per block than exist; _next_batch rejects that.
        num_sample = min(100, train_images.shape[0])

        # os.path.join rather than a hard-coded "\\" so this also works off Windows.
        model_dir = os.path.join(logs_path, "model")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)  # also creates logs_path if needed
        model_path = os.path.join(model_dir, model_path)
        # NOTE(review): UPDATE_OPS (e.g. batch-norm moving averages) are not
        # attached to train_op; harmless while phase is always fed as 1.
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        # tf.global_variables() replaces the deprecated tf.all_variables() alias.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        try:
            sess.run(init)
            # V2 checkpoints are stored as <prefix>.index / <prefix>.data-*, so the
            # bare prefix never exists. Still accept it for old V1 checkpoints.
            if os.path.exists(model_path + ".index") or os.path.exists(model_path):
                saver.restore(sess, model_path)

            display_step = 1
            num_sample_index_in_epoch = 0
            index_in_epoch = 0
            total_steps = train_images.shape[0] * train_epochs

            # Allocated as float64 once; the old per-step astype(np.float) copies
            # were redundant (np.float no longer exists in NumPy >= 1.24).
            subbatch_xs = np.empty((num_sample, self.image_depth, self.image_height, self.image_width,
                                    self.channels), dtype=np.float64)
            subbatch_ys = np.empty((num_sample, self.image_depth, self.image_height, self.image_width,
                                    self.numclass), dtype=np.float64)

            for i in range(total_steps):
                # Refresh the in-memory block of volumes every num_sample steps.
                if i % num_sample == 0:
                    batch_xs_path, batch_ys_path, num_sample_index_in_epoch = _next_batch(
                        train_images, train_lanbels, num_sample, num_sample_index_in_epoch)
                    self._load_subbatch(batch_xs_path, batch_ys_path, subbatch_xs, subbatch_ys)
                batch_xs, batch_ys, index_in_epoch = _next_batch(subbatch_xs, subbatch_ys, batch_size, index_in_epoch)

                # Report on steps 0..9, then every 10th step, then every 100th, ... and the last step.
                if i % display_step == 0 or (i + 1) == total_steps:
                    train_loss, train_accuracy = sess.run(
                        [self.cost, self.accuracy], feed_dict={self.X: batch_xs,
                                                               self.Y_gt: batch_ys,
                                                               self.lr: learning_rate,
                                                               self.phase: 1,
                                                               self.drop: dropout_conv})
                    print('epochs %d training_loss ,training_accuracy ''=> %.5f,%.5f ' % (i, train_loss, train_accuracy))

                    pred = sess.run(self.Y_pred, feed_dict={self.X: batch_xs,
                                                            self.Y_gt: batch_ys,
                                                            self.phase: 1,
                                                            self.drop: 1})
                    self._save_channel_images(batch_ys[0], 'gt', i, showwindow, logs_path)
                    self._save_channel_images(pred[0], 'predict', i, showwindow, logs_path)

                    save_path = saver.save(sess, model_path, global_step=i)
                    print("Model saved in file:", save_path)
                    if i % (display_step * 10) == 0 and i:
                        display_step *= 10

                # train on batch
                _, summary = sess.run([train_op, merged_summary_op], feed_dict={self.X: batch_xs,
                                                                                self.Y_gt: batch_ys,
                                                                                self.lr: learning_rate,
                                                                                self.phase: 1,
                                                                                self.drop: dropout_conv})
                summary_writer.add_summary(summary, i)

            save_path = saver.save(sess, model_path)
            print("Model saved in file:", save_path)
        finally:
            summary_writer.close()
            sess.close()

    def prediction(self, test_images):
        """Run the network on a single volume.

        Requires the instance to have been built with ``inference=True``;
        only that path creates ``self.sess``.

        :param test_images: one volume whose first three axes are
            (depth, height, width). It is reshaped to add a channel axis of
            size ``self.channels``.
        :return: uint8 array (depth, height, width, numclass) holding the
            network output multiplied by 255 and clipped to 0..255.
        """
        test_images = np.reshape(test_images,
                                 (test_images.shape[0], test_images.shape[1], test_images.shape[2], self.channels))
        test_images = test_images.astype(np.float64)
        # Y_pred does not depend on Y_gt. The dummy is still fed, but it now has
        # numclass channels so the placeholder shape check passes for any numclass.
        y_dummy = np.zeros((test_images.shape[0], test_images.shape[1], test_images.shape[2], self.numclass))
        # NOTE(review): phase is fed as 1 here too, matching training.
        pred = self.sess.run(self.Y_pred, feed_dict={self.X: [test_images], self.Y_gt: [y_dummy], self.phase: 1,
                                                     self.drop: 1})
        result = pred.astype(np.float64) * 255.
        result = np.clip(result, 0, 255).astype('uint8')
        result = np.reshape(result, (test_images.shape[0], test_images.shape[1], test_images.shape[2], self.numclass))
        return result