import tensorflow as tf

import DataGenerator

# Inference configuration and graph-level inputs.
batch_size = 2
# Fed to every batch-norm layer; must be fed False at prediction time.
is_training = tf.placeholder(tf.bool)

# Validation data pipeline: normalization only (no augmentation at inference).
valTransforms = [
    DataGenerator.Normalization()
    ]

ValDataset = DataGenerator.DataGenerator(
    data_dir='val-data',
    transforms=valTransforms,
    train=False
    )

valDataset = ValDataset.get_dataset()
valDataset = valDataset.batch(batch_size)

# Reinitializable iterator: the validation set can be (re)started explicitly
# by running `validation_init_op` before pulling batches.
iterator = tf.data.Iterator.from_structure(valDataset.output_types, valDataset.output_shapes)

validation_init_op = iterator.make_initializer(valDataset)
next_item = iterator.get_next()
# Convolution function
def conv3d(x, no_of_input_channels, no_of_filters, filter_size, strides, padding, name):
    """3-D convolution with a learned bias (no activation).

    Args:
        x: input tensor in NDHWC layout.
        no_of_input_channels: channel count of `x`.
        no_of_filters: number of output channels.
        filter_size: spatial kernel size as a 3-element list [d, h, w].
        strides: 5-element stride list [1, d, h, w, 1].
        padding: 'SAME' or 'VALID'.
        name: variable scope and op name.

    Returns:
        The convolution output with bias added.
    """
    with tf.variable_scope(name) as scope:

        initializer = tf.variance_scaling_initializer()

        # Build the full filter shape without mutating the caller's list.
        # The original used `filter_size.extend(...)`, which altered the
        # argument in place — a latent bug if a caller ever reuses the list.
        filter_shape = list(filter_size) + [no_of_input_channels, no_of_filters]
        weights = tf.Variable(initializer(filter_shape), name='weights')
        biases = tf.Variable(initializer([no_of_filters]), name='biases')
        conv = tf.nn.conv3d(x, weights, strides=strides, padding=padding, name=name)
        conv += biases

        return conv
# Transposed convolution function
def upsamp(x, no_of_kernels, name):
    """2x spatial upsampling via a 2x2x2 transposed convolution, stride 2.

    Args:
        x: input tensor in NDHWC layout.
        no_of_kernels: number of output channels.
        name: variable scope name.

    Returns:
        The upsampled tensor (bias included, no activation).
    """
    with tf.variable_scope(name) as scope:
        upsamp = tf.layers.conv3d_transpose(x, no_of_kernels, [2,2,2], 2, padding='VALID', use_bias=True, reuse=tf.AUTO_REUSE)
        return upsamp
# PReLu function
def prelu(x, scope=None):
    """Parametric ReLU: max(0, x) + alpha * min(0, x) with a learned,
    per-channel `alpha` initialized to 0.1.

    NOTE(review): the alpha variable is named "prelu" (not "alpha");
    kept as-is so existing checkpoints still restore.
    """
    with tf.variable_scope(name_or_scope=scope, default_name="prelu", reuse=tf.AUTO_REUSE):
        alpha = tf.get_variable("prelu", shape=x.get_shape()[-1], dtype=x.dtype, initializer=tf.constant_initializer(0.1))
        prelu_out = tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x)
        return prelu_out
# model architecture
def graph_encoder(x):
    """V-Net style encoder.

    Runs four residual stages, each followed by a strided-convolution
    downsampling, and records the tensor produced before each downsampling
    (plus the bottleneck) so the decoder can use them as skip connections.

    Args:
        x: input tensor; model_fn reshapes input to (batch, 128, 128, 64, 1).

    Returns:
        Dict mapping 'res1'..'res5' to the skip-connection tensors
        ('res5' is the bottleneck).
    """
    fine_grained_features = {}

    # Stage 1: 16 channels at full resolution.
    conv1 = conv3d(x,1,16,[3,3,3],[1,1,1,1,1],'SAME','Conv1_1')
    conv1 = conv3d(conv1,16,16,[3,3,3],[1,1,1,1,1],'SAME','Conv1_2')
    conv1 = tf.layers.batch_normalization(conv1, training=is_training)
    conv1 = prelu(conv1,'prelu1')

    # Residual add of the 1-channel input to the 16-channel features
    # (relies on broadcasting along the channel axis).
    res1 = tf.add(x,conv1)
    fine_grained_features['res1'] = res1

    # Downsample by 2 with a strided 2x2x2 convolution, doubling channels.
    down1 = conv3d(res1,16,32,[2,2,2],[1,2,2,2,1],'VALID','DownSampling1')
    down1 = prelu(down1,'down_prelu1')

    # Stage 2: two conv blocks at 32 channels, then residual add.
    conv2 = conv3d(down1,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv2_1')
    conv2 = conv3d(conv2,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv2_2')
    conv2 = tf.layers.batch_normalization(conv2, training=is_training)
    conv2 = prelu(conv2,'prelu2')

    conv3 = conv3d(conv2,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv3_1')
    conv3 = conv3d(conv3,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv3_2')
    conv3 = tf.layers.batch_normalization(conv3, training=is_training)
    conv3 = prelu(conv3,'prelu3')

    res2 = tf.add(down1,conv3)
    fine_grained_features['res2'] = res2

    down2 = conv3d(res2,32,64,[2,2,2],[1,2,2,2,1],'VALID','DownSampling2')
    down2 = prelu(down2,'down_prelu2')

    # Stage 3: three conv blocks at 64 channels, then residual add.
    conv4 = conv3d(down2,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv4_1')
    conv4 = conv3d(conv4,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv4_2')
    conv4 = tf.layers.batch_normalization(conv4, training=is_training)
    conv4 = prelu(conv4,'prelu4')

    conv5 = conv3d(conv4,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv5_1')
    conv5 = conv3d(conv5,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv5_2')
    conv5 = tf.layers.batch_normalization(conv5, training=is_training)
    conv5 = prelu(conv5,'prelu5')

    conv6 = conv3d(conv5,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv6_1')
    conv6 = conv3d(conv6,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv6_2')
    conv6 = tf.layers.batch_normalization(conv6, training=is_training)
    conv6 = prelu(conv6,'prelu6')

    res3 = tf.add(down2,conv6)
    fine_grained_features['res3'] = res3

    down3 = conv3d(res3,64,128,[2,2,2],[1,2,2,2,1],'VALID','DownSampling3')
    down3 = prelu(down3,'down_prelu3')

    # Stage 4: three conv blocks at 128 channels, then residual add.
    conv7 = conv3d(down3,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv7_1')
    conv7 = conv3d(conv7,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv7_2')
    conv7 = tf.layers.batch_normalization(conv7, training=is_training)
    conv7 = prelu(conv7,'prelu7')

    conv8 = conv3d(conv7,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv8_1')
    conv8 = conv3d(conv8,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv8_2')
    conv8 = tf.layers.batch_normalization(conv8, training=is_training)
    conv8 = prelu(conv8,'prelu8')

    conv9 = conv3d(conv8,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv9_1')
    conv9 = conv3d(conv9,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv9_2')
    conv9 = tf.layers.batch_normalization(conv9, training=is_training)
    conv9 = prelu(conv9,'prelu9')

    res4 = tf.add(down3,conv9)
    fine_grained_features['res4'] = res4

    down4 = conv3d(res4,128,256,[2,2,2],[1,2,2,2,1],'VALID','DownSampling4')
    down4 = prelu(down4,'down_prelu4')

    # Bottleneck: three conv blocks at 256 channels, then residual add.
    conv10 = conv3d(down4,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv10_1')
    conv10 = conv3d(conv10,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv10_2')
    conv10 = tf.layers.batch_normalization(conv10, training=is_training)
    conv10 = prelu(conv10,'prelu10')

    conv11 = conv3d(conv10,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv11_1')
    conv11 = conv3d(conv11,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv11_2')
    conv11 = tf.layers.batch_normalization(conv11, training=is_training)
    conv11 = prelu(conv11,'prelu11')

    conv12 = conv3d(conv11,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv12_1')
    conv12 = conv3d(conv12,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv12_2')
    conv12 = tf.layers.batch_normalization(conv12, training=is_training)
    conv12 = prelu(conv12,'prelu12')

    res5 = tf.add(down4,conv12)
    fine_grained_features['res5'] = res5

    return fine_grained_features
def graph_decoder(features):
    """V-Net style decoder.

    Starting from the bottleneck ('res5'), repeatedly upsamples by 2 and
    concatenates the matching encoder skip connection along the channel
    axis, applying residual conv blocks at each resolution. Ends with a
    1x1x1 convolution and a sigmoid to produce a single-channel output.

    Args:
        features: dict 'res1'..'res5' as returned by graph_encoder.

    Returns:
        Sigmoid-activated tensor with one output channel.
    """
    inp = features['res5']

    # 256 -> 128 channels, concat with res4 (128) => 256 channels.
    upsamp1 = upsamp(inp,128,'Upsampling1')
    upsamp1 = prelu(upsamp1,'prelu_upsamp1')

    concat1 = tf.concat([upsamp1,features['res4']],axis=4)

    conv13 = conv3d(concat1,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv13_1')
    conv13 = conv3d(conv13,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv13_2')
    conv13 = tf.layers.batch_normalization(conv13, training=is_training)
    conv13 = prelu(conv13,'prelu13')

    conv14 = conv3d(conv13,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv14_1')
    conv14 = conv3d(conv14,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv14_2')
    conv14 = tf.layers.batch_normalization(conv14, training=is_training)
    conv14 = prelu(conv14,'prelu14')

    conv15 = conv3d(conv14,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv15_1')
    conv15 = conv3d(conv15,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv15_2')
    conv15 = tf.layers.batch_normalization(conv15, training=is_training)
    conv15 = prelu(conv15,'prelu15')

    res6 = tf.add(concat1,conv15)

    # 256 -> 64 channels, concat with res3 (64) => 128 channels.
    upsamp2 = upsamp(res6,64,'Upsampling2')
    upsamp2 = prelu(upsamp2,'prelu_upsamp2')

    concat2 = tf.concat([upsamp2,features['res3']],axis=4)

    conv16 = conv3d(concat2,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv16_1')
    conv16 = conv3d(conv16,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv16_2')
    conv16 = tf.layers.batch_normalization(conv16, training=is_training)
    conv16 = prelu(conv16,'prelu16')

    conv17 = conv3d(conv16,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv17_1')
    conv17 = conv3d(conv17,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv17_2')
    conv17 = tf.layers.batch_normalization(conv17, training=is_training)
    conv17 = prelu(conv17,'prelu17')

    conv18 = conv3d(conv17,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv18_1')
    conv18 = conv3d(conv18,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv18_2')
    conv18 = tf.layers.batch_normalization(conv18, training=is_training)
    conv18 = prelu(conv18,'prelu18')

    res7 = tf.add(concat2,conv18)

    # 128 -> 32 channels, concat with res2 (32) => 64 channels.
    upsamp3 = upsamp(res7,32,'Upsampling3')
    upsamp3 = prelu(upsamp3,'prelu_upsamp3')

    concat3 = tf.concat([upsamp3,features['res2']],axis=4)

    conv19 = conv3d(concat3,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv19_1')
    conv19 = conv3d(conv19,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv19_2')
    conv19 = tf.layers.batch_normalization(conv19, training=is_training)
    conv19 = prelu(conv19,'prelu19')

    conv20 = conv3d(conv19,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv20_1')
    conv20 = conv3d(conv20,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv20_2')
    conv20 = tf.layers.batch_normalization(conv20, training=is_training)
    conv20 = prelu(conv20,'prelu20')

    res8 = tf.add(concat3,conv20)

    # 64 -> 16 channels, concat with res1 (16) => 32 channels.
    upsamp4 = upsamp(res8,16,'Upsampling4')
    upsamp4 = prelu(upsamp4,'prelu_upsamp4')

    concat4 = tf.concat([upsamp4,features['res1']],axis=4)

    conv21 = conv3d(concat4,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv21_1')
    conv21 = conv3d(conv21,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv21_2')
    conv21 = tf.layers.batch_normalization(conv21, training=is_training)
    conv21 = prelu(conv21,'prelu21')

    res9 = tf.add(concat4,conv21)

    # Final 1x1x1 projection to a single channel plus sigmoid.
    conv22 = conv3d(res9,32,1,[1,1,1],[1,1,1,1,1],'SAME','Conv22')
    conv22 = tf.nn.sigmoid(conv22,'sigmoid')
    return conv22
# forward propagation
def model_fn():
    """Build the full forward pass on the iterator's next validation batch.

    Returns:
        Tensor of sigmoid outputs reshaped to (batch, 128, 128, 64).
    """
    # Labels are pulled from the iterator but unused at prediction time.
    features, labels = next_item

    # Add a trailing channel dimension: (batch, 128, 128, 64) -> (..., 1).
    # NOTE(review): assumes each volume is 128x128x64 — confirm against
    # DataGenerator's output shape.
    features = tf.reshape(features, [-1, 128, 128, 64, 1])

    encoded = graph_encoder(features)
    decoded = graph_decoder(encoded)

    # Drop the channel dimension for the caller.
    decoded = tf.reshape(decoded, [-1, 128, 128, 64])

    return decoded
def predict(init_epoch=0):
    """Restore the checkpoint from `init_epoch` and run inference over the
    whole validation set.

    Args:
        init_epoch: epoch number embedded in the checkpoint filename.

    Returns:
        List with one entry per batch; each entry is the result of
        sess.run([decoded]) (a single-element list holding the batch array).
    """
    predictions = []
    decoded = model_fn()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore the trained weights directly into the graph built above.
        # BUG FIX: the original also called tf.train.import_meta_graph(),
        # which loads a SECOND copy of the graph from the .meta file and
        # returns a saver bound to that copy — restore then filled the
        # imported duplicate, leaving the variables `decoded` actually uses
        # at their random initial values.
        saver = tf.train.Saver()
        saver.restore(sess, '/temp/weights_epoch_{0}.ckpt'.format(init_epoch))

        sess.run(validation_init_op)

        # Drain the validation iterator; it raises OutOfRangeError when done.
        while True:
            try:
                pred = sess.run([decoded], feed_dict={is_training: False})
                predictions.append(pred)
            except tf.errors.OutOfRangeError:
                return predictions
if __name__ == '__main__':
    init_epoch = 5000                    # last epoch
    predictions = predict(init_epoch)