|
a |
|
b/saml_func.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
import numpy as np |
|
|
3 |
import sys |
|
|
4 |
import tensorflow as tf |
|
|
5 |
from tensorflow.image import resize_images |
|
|
6 |
# try: |
|
|
7 |
# import special_grads |
|
|
8 |
# except KeyError as e: |
|
|
9 |
# print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr) |
|
|
10 |
|
|
|
11 |
from tensorflow.python.platform import flags |
|
|
12 |
from layer import conv_block, deconv_block, fc, max_pool, concat2d |
|
|
13 |
from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost |
|
|
14 |
|
|
|
15 |
class SAML: |
|
|
16 |
def __init__(self, args): |
|
|
17 |
""" Call construct_model_*() after initializing MASF""" |
|
|
18 |
self.args = args |
|
|
19 |
|
|
|
20 |
self.batch_size = args.meta_batch_size |
|
|
21 |
self.test_batch_size = args.test_batch_size |
|
|
22 |
self.volume_size = args.volume_size |
|
|
23 |
self.n_class = args.n_class |
|
|
24 |
self.compactness_loss_weight = args.compactness_loss_weight |
|
|
25 |
self.smoothness_loss_weight = args.smoothness_loss_weight |
|
|
26 |
self.margin = args.margin |
|
|
27 |
|
|
|
28 |
self.forward = self.forward_unet |
|
|
29 |
self.construct_weights = self.construct_unet_weights |
|
|
30 |
self.seg_loss = _get_segmentation_cost |
|
|
31 |
self.get_compactness_cost = _get_compactness_cost |
|
|
32 |
|
|
|
33 |
def construct_model_train(self, prefix='metatrain_'): |
|
|
34 |
# a: meta-train for inner update, b: meta-test for meta loss |
|
|
35 |
self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) |
|
|
36 |
self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) |
|
|
37 |
self.inputa1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) |
|
|
38 |
self.labela1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) |
|
|
39 |
self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) |
|
|
40 |
self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) |
|
|
41 |
self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) |
|
|
42 |
self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) |
|
|
43 |
self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1]) |
|
|
44 |
self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1]) |
|
|
45 |
self.training_mode = tf.placeholder_with_default(True, shape = None, name = "training_mode_for_bn_moving") |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
self.clip_value = self.args.gradients_clip_value |
|
|
49 |
self.KEEP_PROB = tf.placeholder(tf.float32) |
|
|
50 |
|
|
|
51 |
with tf.variable_scope('model', reuse=None) as training_scope: |
|
|
52 |
if 'weights' in dir(self): |
|
|
53 |
print('weights already defined') |
|
|
54 |
training_scope.reuse_variables() |
|
|
55 |
weights = self.weights |
|
|
56 |
else: |
|
|
57 |
# Define the weights |
|
|
58 |
self.weights = weights = self.construct_weights() |
|
|
59 |
|
|
|
60 |
def task_metalearn(inp, reuse=True): |
|
|
61 |
# Function to perform meta learning update """ |
|
|
62 |
inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp |
|
|
63 |
|
|
|
64 |
# Obtaining the conventional task loss on meta-train |
|
|
65 |
task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode) |
|
|
66 |
task_lossa = self.seg_loss(task_outputa, labela) |
|
|
67 |
task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode) |
|
|
68 |
task_lossa1 = self.seg_loss(task_outputa1, labela1) |
|
|
69 |
|
|
|
70 |
## perform inner update with plain gradient descent on meta-train |
|
|
71 |
grads = tf.gradients((task_lossa + task_lossa1)/2.0, list(weights.values())) |
|
|
72 |
grads = [tf.stop_gradient(grad) for grad in grads] # first-order gradients approximation |
|
|
73 |
gradients = dict(zip(weights.keys(), grads)) |
|
|
74 |
# fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()])) |
|
|
75 |
fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()])) |
|
|
76 |
|
|
|
77 |
## compute compactness loss |
|
|
78 |
task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode) |
|
|
79 |
task_lossb = self.seg_loss(task_outputb, labelb) |
|
|
80 |
compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb) |
|
|
81 |
compactness_loss_b = self.compactness_loss_weight * compactness_loss_b |
|
|
82 |
|
|
|
83 |
# compute smoothness loss |
|
|
84 |
_, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode) |
|
|
85 |
coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings) |
|
|
86 |
metric_embeddings = self.forward_metric_net(coutour_embeddings) |
|
|
87 |
|
|
|
88 |
print (metric_label_group.shape) |
|
|
89 |
print (metric_embeddings.shape) |
|
|
90 |
smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin) |
|
|
91 |
smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b |
|
|
92 |
task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1] |
|
|
93 |
|
|
|
94 |
return task_output |
|
|
95 |
|
|
|
96 |
self.global_step = tf.Variable(0, trainable=False) |
|
|
97 |
# self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate) |
|
|
98 |
# self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate) |
|
|
99 |
self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False) |
|
|
100 |
self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False) |
|
|
101 |
self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False) |
|
|
102 |
|
|
|
103 |
input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group) |
|
|
104 |
result = task_metalearn(inp=input_tensors) |
|
|
105 |
self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1= result |
|
|
106 |
|
|
|
107 |
## Performance & Optimization |
|
|
108 |
if 'train' in prefix: |
|
|
109 |
self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0 |
|
|
110 |
self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b |
|
|
111 |
|
|
|
112 |
var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')] |
|
|
113 |
var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')] |
|
|
114 |
|
|
|
115 |
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) |
|
|
116 |
with tf.control_dependencies(update_ops): |
|
|
117 |
self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step) |
|
|
118 |
|
|
|
119 |
optimizer = tf.train.AdamOptimizer(self.outer_lr) |
|
|
120 |
gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor) |
|
|
121 |
|
|
|
122 |
# observe stability of gradients for meta loss |
|
|
123 |
# l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2))) |
|
|
124 |
# for grad, var in gvs: |
|
|
125 |
# tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad)) |
|
|
126 |
# tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var)) |
|
|
127 |
# tf.summary.histogram('gradients/' + var.name, var) |
|
|
128 |
# tf.summary.histogram("feature_extractor_var/" + var.name, var) |
|
|
129 |
|
|
|
130 |
# gvs = [(grad, var) for grad, var in gvs] |
|
|
131 |
gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs] |
|
|
132 |
self.meta_train_op = optimizer.apply_gradients(gvs) |
|
|
133 |
|
|
|
134 |
# for grad, var in gvs: |
|
|
135 |
# tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad)) |
|
|
136 |
# tf.summary.histogram('gradients_clipped/' + var.name, var) |
|
|
137 |
|
|
|
138 |
self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric) |
|
|
139 |
|
|
|
140 |
## Summaries |
|
|
141 |
# scalar_summaries = [] |
|
|
142 |
# train_images = [] |
|
|
143 |
# val_images = [] |
|
|
144 |
|
|
|
145 |
tf.summary.scalar(prefix+'source_1 loss', self.seg_loss_a) |
|
|
146 |
tf.summary.scalar(prefix+'source_2 loss', self.seg_loss_a1) |
|
|
147 |
tf.summary.scalar(prefix+'target_loss', self.seg_loss_b) |
|
|
148 |
tf.summary.scalar(prefix+'target_coutour_loss', self.compactness_loss_b) |
|
|
149 |
tf.summary.scalar(prefix+'target_length', self.length) |
|
|
150 |
tf.summary.scalar(prefix+'target_area', self.area) |
|
|
151 |
tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3)) |
|
|
152 |
tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:,:,:,1], tf.float32), 3)) |
|
|
153 |
tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:,:,:,1], tf.float32), 3)) |
|
|
154 |
tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:,:,:], tf.float32), 3)) |
|
|
155 |
tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:,:,:, 0], tf.float32), 3)) |
|
|
156 |
tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:,:,:, 1], tf.float32), 3)) |
|
|
157 |
tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:,:,:, 1], tf.float32), 3)) |
|
|
158 |
|
|
|
159 |
def extract_coutour_embedding(self, coutour, embeddings): |
|
|
160 |
|
|
|
161 |
coutour_embeddings = coutour * embeddings |
|
|
162 |
average_embeddings = tf.reduce_sum(coutour_embeddings, [1,2])/tf.reduce_sum(coutour, [1,2]) |
|
|
163 |
# print (coutour.shape) |
|
|
164 |
# print (embeddings.shape) |
|
|
165 |
# print (coutour_embeddings.shape) |
|
|
166 |
# print (average_embeddings.shape) |
|
|
167 |
return average_embeddings |
|
|
168 |
|
|
|
169 |
def construct_model_test(self, prefix='test'): |
|
|
170 |
self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) |
|
|
171 |
self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) |
|
|
172 |
|
|
|
173 |
with tf.variable_scope('model', reuse=None) as testing_scope: |
|
|
174 |
if 'weights' in dir(self): |
|
|
175 |
testing_scope.reuse_variables() |
|
|
176 |
weights = self.weights |
|
|
177 |
else: |
|
|
178 |
raise ValueError('Weights not initilized. Create training model before testing model') |
|
|
179 |
|
|
|
180 |
outputs, mask, _ = self.forward(self.test_input, weights) |
|
|
181 |
losses = self.seg_loss(outputs, self.test_label) |
|
|
182 |
# self.pred_prob = tf.nn.softmax(outputs) |
|
|
183 |
self.outputs = mask |
|
|
184 |
|
|
|
185 |
self.test_loss = losses |
|
|
186 |
# self.test_acc = accuracies |
|
|
187 |
|
|
|
188 |
def forward_metric_net(self, x): |
|
|
189 |
|
|
|
190 |
with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope: |
|
|
191 |
|
|
|
192 |
w1 = tf.get_variable('w1', shape=[48,24]) |
|
|
193 |
b1 = tf.get_variable('b1', shape=[24]) |
|
|
194 |
out = fc(x, w1, b1, activation='leaky_relu') |
|
|
195 |
w2 = tf.get_variable('w2', shape=[24,16]) |
|
|
196 |
b2 = tf.get_variable('b2', shape=[16]) |
|
|
197 |
out = fc(out, w2, b2, activation='leaky_relu') |
|
|
198 |
|
|
|
199 |
return out |
|
|
200 |
|
|
|
201 |
def construct_unet_weights(self): |
|
|
202 |
|
|
|
203 |
weights = {} |
|
|
204 |
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32) |
|
|
205 |
|
|
|
206 |
with tf.variable_scope('conv1') as scope: |
|
|
207 |
weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer) |
|
|
208 |
weights['conv11_biases'] = tf.get_variable('biases', [16]) |
|
|
209 |
weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer) |
|
|
210 |
weights['conv12_biases'] = tf.get_variable('biases2', [16]) |
|
|
211 |
|
|
|
212 |
with tf.variable_scope('conv2') as scope: |
|
|
213 |
weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer) |
|
|
214 |
weights['conv21_biases'] = tf.get_variable('biases', [32]) |
|
|
215 |
weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer) |
|
|
216 |
weights['conv22_biases'] = tf.get_variable('biases2', [32]) |
|
|
217 |
## Network has downsample here |
|
|
218 |
|
|
|
219 |
with tf.variable_scope('conv3') as scope: |
|
|
220 |
weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer) |
|
|
221 |
weights['conv31_biases'] = tf.get_variable('biases', [64]) |
|
|
222 |
weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer) |
|
|
223 |
weights['conv32_biases'] = tf.get_variable('biases2', [64]) |
|
|
224 |
|
|
|
225 |
with tf.variable_scope('conv4') as scope: |
|
|
226 |
weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer) |
|
|
227 |
weights['conv41_biases'] = tf.get_variable('biases', [128]) |
|
|
228 |
weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer) |
|
|
229 |
weights['conv42_biases'] = tf.get_variable('biases2', [128]) |
|
|
230 |
## Network has downsample here |
|
|
231 |
|
|
|
232 |
with tf.variable_scope('conv5') as scope: |
|
|
233 |
weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer) |
|
|
234 |
weights['conv51_biases'] = tf.get_variable('biases', [256]) |
|
|
235 |
weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer) |
|
|
236 |
weights['conv52_biases'] = tf.get_variable('biases2', [256]) |
|
|
237 |
|
|
|
238 |
with tf.variable_scope('deconv6') as scope: |
|
|
239 |
weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer) |
|
|
240 |
weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer) |
|
|
241 |
weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer) |
|
|
242 |
weights['conv61_biases'] = tf.get_variable('biases', [128]) |
|
|
243 |
weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer) |
|
|
244 |
weights['conv62_biases'] = tf.get_variable('biases2', [128]) |
|
|
245 |
|
|
|
246 |
with tf.variable_scope('deconv7') as scope: |
|
|
247 |
weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer) |
|
|
248 |
weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer) |
|
|
249 |
weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer) |
|
|
250 |
weights['conv71_biases'] = tf.get_variable('biases', [64]) |
|
|
251 |
weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer) |
|
|
252 |
weights['conv72_biases'] = tf.get_variable('biases2', [64]) |
|
|
253 |
|
|
|
254 |
with tf.variable_scope('deconv8') as scope: |
|
|
255 |
weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer) |
|
|
256 |
weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer) |
|
|
257 |
weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer) |
|
|
258 |
weights['conv81_biases'] = tf.get_variable('biases', [32]) |
|
|
259 |
weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer) |
|
|
260 |
weights['conv82_biases'] = tf.get_variable('biases2', [32]) |
|
|
261 |
|
|
|
262 |
with tf.variable_scope('deconv9') as scope: |
|
|
263 |
weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer) |
|
|
264 |
weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer) |
|
|
265 |
weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer) |
|
|
266 |
weights['conv91_biases'] = tf.get_variable('biases', [16]) |
|
|
267 |
weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer) |
|
|
268 |
weights['conv92_biases'] = tf.get_variable('biases2', [16]) |
|
|
269 |
|
|
|
270 |
with tf.variable_scope('output') as scope: |
|
|
271 |
weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer) |
|
|
272 |
weights['output_biases'] = tf.get_variable('biases', [2]) |
|
|
273 |
|
|
|
274 |
return weights |
|
|
275 |
|
|
|
276 |
def forward_unet(self, inp, weights, is_training=True): |
|
|
277 |
|
|
|
278 |
self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training) |
|
|
279 |
self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training) |
|
|
280 |
self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID') |
|
|
281 |
# 192x192x16 |
|
|
282 |
self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training) |
|
|
283 |
self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training) |
|
|
284 |
self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID') |
|
|
285 |
# 96x96x32 |
|
|
286 |
self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training) |
|
|
287 |
self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training) |
|
|
288 |
self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID') |
|
|
289 |
# 48x48x64 |
|
|
290 |
self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training) |
|
|
291 |
self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training) |
|
|
292 |
self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID') |
|
|
293 |
# 24x24x128 |
|
|
294 |
self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training) |
|
|
295 |
self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training) |
|
|
296 |
# 24x24x256 |
|
|
297 |
|
|
|
298 |
## add upsampling, meanwhile, channel number is reduced to half |
|
|
299 |
self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training) |
|
|
300 |
# 48x48x128 |
|
|
301 |
self.sum6 = concat2d(self.deconv6, self.deconv6) |
|
|
302 |
self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training) |
|
|
303 |
self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training) |
|
|
304 |
# 48x48x128 |
|
|
305 |
|
|
|
306 |
self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training) |
|
|
307 |
# 96x96x64 |
|
|
308 |
self.sum7 = concat2d(self.deconv7, self.deconv7) |
|
|
309 |
self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training) |
|
|
310 |
self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training) |
|
|
311 |
# 96x96x64 |
|
|
312 |
|
|
|
313 |
self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training) |
|
|
314 |
# 192x192x32 |
|
|
315 |
self.sum8 = concat2d(self.deconv8, self.deconv8) |
|
|
316 |
self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training) |
|
|
317 |
self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training) |
|
|
318 |
self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False) |
|
|
319 |
# 192x192x32 |
|
|
320 |
|
|
|
321 |
self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training) |
|
|
322 |
# 384x384x16 |
|
|
323 |
self.sum9 = concat2d(self.deconv9, self.deconv9) |
|
|
324 |
self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training) |
|
|
325 |
self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training) |
|
|
326 |
# 384x384x16 |
|
|
327 |
|
|
|
328 |
self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='outpu/bn', bn=False, is_training=is_training) |
|
|
329 |
#384x384x2 |
|
|
330 |
|
|
|
331 |
self.pred_prob = tf.nn.softmax(self.logits) # shape [batch, w, h, num_classes] |
|
|
332 |
self.pred_compact = tf.argmax(self.pred_prob, axis=-1) # shape [batch, w, h] |
|
|
333 |
|
|
|
334 |
self.embeddings = concat2d(self.conv82_resize, self.conv92) |
|
|
335 |
|
|
|
336 |
return self.pred_prob, self.pred_compact, self.embeddings |