Diff of /saml_func.py [000000] .. [7b5b9f]

from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
from tensorflow.image import resize_images
# try:
#     import special_grads
# except KeyError as e:
#     print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr)

from tensorflow.python.platform import flags
from layer import conv_block, deconv_block, fc, max_pool, concat2d
from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost

class SAML:
    def __init__(self, args):
        """Call construct_model_*() after initializing SAML."""
        self.args = args

        self.batch_size = args.meta_batch_size
        self.test_batch_size = args.test_batch_size
        self.volume_size = args.volume_size
        self.n_class = args.n_class
        self.compactness_loss_weight = args.compactness_loss_weight
        self.smoothness_loss_weight = args.smoothness_loss_weight
        self.margin = args.margin

        self.forward = self.forward_unet
        self.construct_weights = self.construct_unet_weights
        self.seg_loss = _get_segmentation_cost
        self.get_compactness_cost = _get_compactness_cost

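    # Training overview: each step forms one meta-learning episode. Two source
    # domains (inputa, inputa1) drive the inner gradient update, and a held-out
    # meta-test domain (inputb) supplies the segmentation, compactness, and
    # smoothness (metric-embedding) meta-objectives computed below.
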
    def construct_model_train(self, prefix='metatrain_'):
        # a: meta-train for inner update, b: meta-test for meta loss
        self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
        self.inputa1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.labela1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
        self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
        # group inputs feed the embedding branch: sampled contour/background masks with metric labels
        self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
        self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1])
        self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
        self.training_mode = tf.placeholder_with_default(True, shape=None, name="training_mode_for_bn_moving")

        self.clip_value = self.args.gradients_clip_value
        self.KEEP_PROB = tf.placeholder(tf.float32)

        with tf.variable_scope('model', reuse=None) as training_scope:
            if 'weights' in dir(self):
                print('weights already defined')
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()

            def task_metalearn(inp, reuse=True):
                """Perform one inner/outer meta-learning step for a task."""
                inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp

                # Obtain the conventional task loss on meta-train
                task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode)
                task_lossa = self.seg_loss(task_outputa, labela)
                task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode)
                task_lossa1 = self.seg_loss(task_outputa1, labela1)

                ## perform inner update with plain gradient descent on meta-train
                grads = tf.gradients((task_lossa + task_lossa1) / 2.0, list(weights.values()))
                grads = [tf.stop_gradient(grad) for grad in grads]  # first-order gradient approximation
                gradients = dict(zip(weights.keys(), grads))
                # fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()]))
                fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()]))

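                # Inner update recap: this is a first-order MAML step,
                #   fast_weights = weights - inner_lr * clip(d L_source / d weights),
                # where stop_gradient discards second-order terms, so the
                # meta-losses below differentiate through fast_weights as if
                # the inner gradient were a constant.
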
                ## compute compactness loss
                task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode)
                task_lossb = self.seg_loss(task_outputb, labelb)
                compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb)
                compactness_loss_b = self.compactness_loss_weight * compactness_loss_b

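                # get_compactness_cost (from utils) also returns the predicted
                # boundary length, the region area, and the boundary map used
                # in the summaries; a shape-compactness term of this kind is
                # typically the iso-perimetric quotient length^2 / (4*pi*area),
                # though the exact form lives in utils.
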
                # compute smoothness loss
                _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode)
                contour_embeddings = self.extract_contour_embedding(contour_group, embeddings)
                metric_embeddings = self.forward_metric_net(contour_embeddings)

                print(metric_label_group.shape)
                print(metric_embeddings.shape)
                smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin)
                smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b
                task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1]

                return task_output

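                # triplet_semihard_loss pulls same-label embeddings together
                # and pushes different-label embeddings at least self.margin
                # apart, mining semi-hard negatives within the batch; here the
                # labels mark contour vs. background samples drawn across
                # domains.
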
            self.global_step = tf.Variable(0, trainable=False)
            # self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
            # self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
            self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False)
            self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False)
            self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False)

            input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group)
            result = task_metalearn(inp=input_tensors)
            self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1 = result

        ## Performance & Optimization
        if 'train' in prefix:
            self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0
            self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b

            var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')]
            var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')]

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step)

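            # Wrapping the source-domain update in UPDATE_OPS dependencies
            # keeps batch-norm moving statistics in sync; this op is also the
            # only one that advances global_step.
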
            optimizer = tf.train.AdamOptimizer(self.outer_lr)
            gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor)

            # observe stability of gradients for meta loss
            # l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
            # for grad, var in gvs:
            #     tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad))
            #     tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var))
            #     tf.summary.histogram('gradients/' + var.name, var)
            #     tf.summary.histogram("feature_extractor_var/" + var.name, var)

            # gvs = [(grad, var) for grad, var in gvs]
            gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs]
            self.meta_train_op = optimizer.apply_gradients(gvs)

            # for grad, var in gvs:
            #     tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad))
            #     tf.summary.histogram('gradients_clipped/' + var.name, var)

            self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric)

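            # Three ops, three variable groups: task_train_op fits the
            # segmentor on the source domains, meta_train_op applies the
            # clipped meta-gradient of the meta-test objective to the
            # segmentor, and metric_train_op updates only the 'metric'
            # embedding head.
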
        ## Summaries
        # scalar_summaries = []
        # train_images = []
        # val_images = []

        tf.summary.scalar(prefix + 'source_1_loss', self.seg_loss_a)
        tf.summary.scalar(prefix + 'source_2_loss', self.seg_loss_a1)
        tf.summary.scalar(prefix + 'target_loss', self.seg_loss_b)
        tf.summary.scalar(prefix + 'target_contour_loss', self.compactness_loss_b)
        tf.summary.scalar(prefix + 'target_length', self.length)
        tf.summary.scalar(prefix + 'target_area', self.area)
        tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3))
        tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:, :, :, 1], tf.float32), 3))
        tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:, :, :, 1], tf.float32), 3))
        tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:, :, :], tf.float32), 3))
        tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:, :, :, 0], tf.float32), 3))
        tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:, :, :, 1], tf.float32), 3))
        tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:, :, :, 1], tf.float32), 3))

    def extract_contour_embedding(self, contour, embeddings):
        """Masked average pooling: mean embedding over the contour mask."""
        contour_embeddings = contour * embeddings
        average_embeddings = tf.reduce_sum(contour_embeddings, [1, 2]) / tf.reduce_sum(contour, [1, 2])
        # print (contour.shape)
        # print (embeddings.shape)
        # print (contour_embeddings.shape)
        # print (average_embeddings.shape)
        return average_embeddings

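    # Shape flow above: contour [B, H, W, 1] broadcasts against embeddings
    # [B, H, W, C]; summing over H and W and dividing by the mask size yields
    # one C-dimensional embedding per image (a masked global average pool).
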
    def construct_model_test(self, prefix='test'):
        self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class])

        with tf.variable_scope('model', reuse=None) as testing_scope:
            if 'weights' in dir(self):
                testing_scope.reuse_variables()
                weights = self.weights
            else:
                raise ValueError('Weights not initialized. Create the training model before the testing model.')

            outputs, mask, _ = self.forward(self.test_input, weights)
            losses = self.seg_loss(outputs, self.test_label)
            # self.pred_prob = tf.nn.softmax(outputs)
            self.outputs = mask

        self.test_loss = losses
        # self.test_acc = accuracies

    def forward_metric_net(self, x):

        with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope:

            w1 = tf.get_variable('w1', shape=[48, 24])
            b1 = tf.get_variable('b1', shape=[24])
            out = fc(x, w1, b1, activation='leaky_relu')
            w2 = tf.get_variable('w2', shape=[24, 16])
            b2 = tf.get_variable('b2', shape=[16])
            out = fc(out, w2, b2, activation='leaky_relu')

        return out

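    # The 48-unit input of w1 matches the embedding tensor produced by
    # forward_unet: concat2d(conv82_resize, conv92) stacks 32 + 16 channels,
    # which this two-layer MLP projects down to 16-D metric embeddings.
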
    def construct_unet_weights(self):

        weights = {}
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)

        with tf.variable_scope('conv1') as scope:
            weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer)
            weights['conv11_biases'] = tf.get_variable('biases', [16])
            weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer)
            weights['conv12_biases'] = tf.get_variable('biases2', [16])

        with tf.variable_scope('conv2') as scope:
            weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer)
            weights['conv21_biases'] = tf.get_variable('biases', [32])
            weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer)
            weights['conv22_biases'] = tf.get_variable('biases2', [32])
        ## Network has downsample here

        with tf.variable_scope('conv3') as scope:
            weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer)
            weights['conv31_biases'] = tf.get_variable('biases', [64])
            weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
            weights['conv32_biases'] = tf.get_variable('biases2', [64])

        with tf.variable_scope('conv4') as scope:
            weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer)
            weights['conv41_biases'] = tf.get_variable('biases', [128])
            weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
            weights['conv42_biases'] = tf.get_variable('biases2', [128])
        ## Network has downsample here

        with tf.variable_scope('conv5') as scope:
            weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer)
            weights['conv51_biases'] = tf.get_variable('biases', [256])
            weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer)
            weights['conv52_biases'] = tf.get_variable('biases2', [256])

        # deconvolution filters follow [kernel_h, kernel_w, out_channels, in_channels]
        with tf.variable_scope('deconv6') as scope:
            weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer)
            weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer)
            weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer)
            weights['conv61_biases'] = tf.get_variable('biases', [128])
            weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
            weights['conv62_biases'] = tf.get_variable('biases2', [128])

        with tf.variable_scope('deconv7') as scope:
            weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer)
            weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer)
            weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer)
            weights['conv71_biases'] = tf.get_variable('biases', [64])
            weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
            weights['conv72_biases'] = tf.get_variable('biases2', [64])

        with tf.variable_scope('deconv8') as scope:
            weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer)
            weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer)
            weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer)
            weights['conv81_biases'] = tf.get_variable('biases', [32])
            weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer)
            weights['conv82_biases'] = tf.get_variable('biases2', [32])

        with tf.variable_scope('deconv9') as scope:
            weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer)
            weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer)
            weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer)
            weights['conv91_biases'] = tf.get_variable('biases', [16])
            weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer)
            weights['conv92_biases'] = tf.get_variable('biases2', [16])

        with tf.variable_scope('output') as scope:
            weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer)
            weights['output_biases'] = tf.get_variable('biases', [2])

        return weights

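    # Channel plan: encoder 3-16-32-64-128-256 with 2x2 max-pooling between
    # stages; the decoder mirrors it back down to 16 channels before the
    # 2-class output head.
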
    def forward_unet(self, inp, weights, is_training=True):

        self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training)
        self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training)
        self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID')
        # 192x192x16
        self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training)
        self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training)
        self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID')
        # 96x96x32
        self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training)
        self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training)
        self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID')
        # 48x48x64
        self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training)
        self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training)
        self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID')
        # 24x24x128
        self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training)
        self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training)
        # 24x24x256

        ## upsampling path: each deconv doubles the resolution and halves the channel count
        self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training)
        # 48x48x128
        self.sum6 = concat2d(self.deconv6, self.conv42)  # skip connection from the matching encoder stage
        self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training)
        self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training)
        # 48x48x128

        self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training)
        # 96x96x64
        self.sum7 = concat2d(self.deconv7, self.conv32)  # skip connection
        self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training)
        self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training)
        # 96x96x64

        self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training)
        # 192x192x32
        self.sum8 = concat2d(self.deconv8, self.conv22)  # skip connection
        self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training)
        self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training)
        self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
        # 192x192x32

        self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training)
        # 384x384x16
        self.sum9 = concat2d(self.deconv9, self.conv12)  # skip connection
        self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training)
        self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training)
        # 384x384x16

        self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='output/bn', bn=False, is_training=is_training)
        # 384x384x2

        self.pred_prob = tf.nn.softmax(self.logits)  # shape [batch, w, h, num_classes]
        self.pred_compact = tf.argmax(self.pred_prob, axis=-1)  # shape [batch, w, h]

        # 48-channel embedding: resized decoder features stacked with the final features
        self.embeddings = concat2d(self.conv82_resize, self.conv92)

        return self.pred_prob, self.pred_compact, self.embeddings
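
# ---------------------------------------------------------------------------
# Minimal usage sketch. This block is illustrative, not part of the original
# file: it assumes the repo's `layer` and `utils` modules are importable, and
# the hyper-parameter values below are placeholders, not the authors' settings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    # Field names mirror the args.* attributes read in SAML.__init__ and
    # construct_model_train; all values here are hypothetical.
    args = Namespace(
        meta_batch_size=2,
        test_batch_size=1,
        volume_size=[384, 384, 3],
        n_class=2,
        compactness_loss_weight=1.0,
        smoothness_loss_weight=0.1,
        margin=10.0,
        gradients_clip_value=10.0,
        inner_lr=1e-3,
        outer_lr=1e-3,
        metric_lr=1e-3,
    )

    model = SAML(args)
    model.construct_model_train(prefix='metatrain_')
    model.construct_model_test(prefix='test')
    summ_op = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('graph built; train ops:', model.task_train_op.name,
              model.meta_train_op.name, model.metric_train_op.name)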