Diff of /saml_func.py [000000] .. [7b5b9f]

--- a
+++ b/saml_func.py
@@ -0,0 +1,336 @@
+from __future__ import print_function
+import numpy as np
+import sys
+import tensorflow as tf
+from tensorflow.image import resize_images
+# try:
+#     import special_grads
+# except KeyError as e:
+#     print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr)
+
+from tensorflow.python.platform import flags
+from layer import conv_block, deconv_block, fc, max_pool, concat2d
+from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost
+
+class SAML:
+    def __init__(self, args):
+        """ Call construct_model_*() after initializing MASF"""
+        self.args = args
+
+        self.batch_size = args.meta_batch_size
+        self.test_batch_size = args.test_batch_size
+        self.volume_size = args.volume_size
+        self.n_class = args.n_class
+        self.compactness_loss_weight = args.compactness_loss_weight
+        self.smoothness_loss_weight = args.smoothness_loss_weight
+        self.margin = args.margin
+
+        self.forward = self.forward_unet
+        self.construct_weights = self.construct_unet_weights
+        self.seg_loss = _get_segmentation_cost
+        self.get_compactness_cost = _get_compactness_cost
+
+    def construct_model_train(self, prefix='metatrain_'):
+        # a: meta-train for inner update, b: meta-test for meta loss
+        self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.inputa1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labela1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
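+        # group inputs for the contour-based metric (smoothness) loss: image slices,
+        # their contour/background sampling masks, and the integer labels fed to the triplet loss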
+        self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1])
+        self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
+        self.training_mode = tf.placeholder_with_default(True, shape=None, name="training_mode_for_bn_moving")
+
+
+        self.clip_value = self.args.gradients_clip_value
+        self.KEEP_PROB = tf.placeholder(tf.float32)
+
+        with tf.variable_scope('model', reuse=None) as training_scope:
+            if 'weights' in dir(self):
+                print('weights already defined')
+                training_scope.reuse_variables()
+                weights = self.weights
+            else:
+                # Define the weights
+                self.weights = weights = self.construct_weights()
+
+            def task_metalearn(inp, reuse=True):
+                # Function to perform the meta-learning update
+                inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp
+
+                # Obtaining the conventional task loss on meta-train
+                task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode)
+                task_lossa = self.seg_loss(task_outputa, labela)
+                task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode)
+                task_lossa1 = self.seg_loss(task_outputa1, labela1)
+
+                ## perform inner update with plain gradient descent on meta-train
+                grads = tf.gradients((task_lossa + task_lossa1)/2.0, list(weights.values()))
+                grads = [tf.stop_gradient(grad) for grad in grads] # first-order gradients approximation
+                gradients = dict(zip(weights.keys(), grads))
+                # fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()]))
+                fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()]))
+
+                ## compute compactness loss
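+                # the compactness cost comes back together with the boundary length, area and boundary map of the prediction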
+                task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode)
+                task_lossb = self.seg_loss(task_outputb, labelb)
+                compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb)
+                compactness_loss_b = self.compactness_loss_weight * compactness_loss_b
+
+                # compute smoothness loss
+                _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode)
+                coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings)
+                metric_embeddings = self.forward_metric_net(coutour_embeddings)
+
+                print(metric_label_group.shape)
+                print(metric_embeddings.shape)
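+                # semi-hard triplet loss over the metric embeddings, with metric_label_group providing the triplet labels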
+                smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin)
+                smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b
+                task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1]
+
+                return task_output
+
+            self.global_step = tf.Variable(0, trainable=False)
+            # self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
+            # self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
+            self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False)
+            self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False)
+            self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False)
+
+            input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group)
+            result = task_metalearn(inp=input_tensors)
+            self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1 = result
+
+        ## Performance & Optimization
+        if 'train' in prefix:
+            self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0
+            self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b
+
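+            # three optimizers below: task_train_op fits the segmentor to the source (meta-train) losses,
+            # meta_train_op applies the clipped meta-gradient of the target (meta-test) loss,
+            # and metric_train_op updates only the metric head on the smoothness loss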
+            var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')]
+            var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')]
+
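+            # run the UPDATE_OPS (batch-norm moving-average updates) together with the inner task update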
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+            with tf.control_dependencies(update_ops):
+                self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step)
+
+            optimizer = tf.train.AdamOptimizer(self.outer_lr)
+            gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor)
+
+            # observe stability of gradients for meta loss
+            # l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
+            # for grad, var in gvs:
+            #     tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad))
+            #     tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var))
+            #     tf.summary.histogram('gradients/' + var.name, var)
+            #     tf.summary.histogram("feature_extractor_var/" + var.name, var)
+
+            # gvs = [(grad, var) for grad, var in gvs]
+            gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs]
+            self.meta_train_op = optimizer.apply_gradients(gvs)
+
+            # for grad, var in gvs:
+            #     tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad))
+            #     tf.summary.histogram('gradients_clipped/' + var.name, var)
+
+            self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric)
+
+        ## Summaries
+        # scalar_summaries = []
+        # train_images = []
+        # val_images = []
+
+        tf.summary.scalar(prefix+'source_1 loss', self.seg_loss_a)
+        tf.summary.scalar(prefix+'source_2 loss', self.seg_loss_a1)
+        tf.summary.scalar(prefix+'target_loss', self.seg_loss_b)
+        tf.summary.scalar(prefix+'target_coutour_loss', self.compactness_loss_b)
+        tf.summary.scalar(prefix+'target_length', self.length)
+        tf.summary.scalar(prefix+'target_area', self.area)
+        tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3))
+        tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:,:,:,1], tf.float32), 3))
+        tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:,:,:,1], tf.float32), 3))
+        tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:,:,:], tf.float32), 3))
+        tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:,:,:, 0], tf.float32), 3))
+        tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:,:,:, 1], tf.float32), 3))
+        tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:,:,:, 1], tf.float32), 3))
+
+    def extract_coutour_embedding(self, coutour, embeddings):
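+        # mask the embedding map with the contour/background sampling map and average it
+        # over the spatial dimensions, giving one embedding vector per sample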
+
+        coutour_embeddings = coutour * embeddings
+        average_embeddings = tf.reduce_sum(coutour_embeddings, [1,2])/tf.reduce_sum(coutour, [1,2])
+        # print (coutour.shape)
+        # print (embeddings.shape)
+        # print (coutour_embeddings.shape)
+        # print (average_embeddings.shape)
+        return average_embeddings
+
+    def construct_model_test(self, prefix='test'):
+        self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+
+        with tf.variable_scope('model', reuse=None) as testing_scope:
+            if 'weights' in dir(self):
+                testing_scope.reuse_variables()
+                weights = self.weights
+            else:
+                raise ValueError('Weights not initialized. Create the training model before the testing model')
+
+            outputs, mask, _ = self.forward(self.test_input, weights)
+            losses = self.seg_loss(outputs, self.test_label)
+            # self.pred_prob = tf.nn.softmax(outputs)
+            self.outputs = mask
+
+        self.test_loss = losses
+        # self.test_acc = accuracies
+
+    def forward_metric_net(self, x):
+
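+        # small metric head: two fully-connected layers (48 -> 24 -> 16) with leaky-ReLU,
+        # mapping the averaged contour embeddings into the space used by the triplet loss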
+        with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope:
+
+            w1 = tf.get_variable('w1', shape=[48,24])
+            b1 = tf.get_variable('b1', shape=[24])
+            out = fc(x, w1, b1, activation='leaky_relu')
+            w2 = tf.get_variable('w2', shape=[24,16])
+            b2 = tf.get_variable('b2', shape=[16])
+            out = fc(out, w2, b2, activation='leaky_relu')
+
+        return out
+
+    def construct_unet_weights(self):
+
+        weights = {}
+        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
+
+        with tf.variable_scope('conv1') as scope:
+            weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer)
+            weights['conv11_biases'] = tf.get_variable('biases', [16])
+            weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer)
+            weights['conv12_biases'] = tf.get_variable('biases2', [16])
+
+        with tf.variable_scope('conv2') as scope:
+            weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer)
+            weights['conv21_biases'] = tf.get_variable('biases', [32])
+            weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer)
+            weights['conv22_biases'] = tf.get_variable('biases2', [32])
+        ## Network has downsample here
+
+        with tf.variable_scope('conv3') as scope:
+            weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer)
+            weights['conv31_biases'] = tf.get_variable('biases', [64])
+            weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
+            weights['conv32_biases'] = tf.get_variable('biases2', [64])
+
+        with tf.variable_scope('conv4') as scope:
+            weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer)
+            weights['conv41_biases'] = tf.get_variable('biases', [128])
+            weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
+            weights['conv42_biases'] = tf.get_variable('biases2', [128])
+        ## Network has downsample here
+
+        with tf.variable_scope('conv5') as scope:
+            weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer)
+            weights['conv51_biases'] = tf.get_variable('biases', [256])
+            weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer)
+            weights['conv52_biases'] = tf.get_variable('biases2', [256])
+
+        with tf.variable_scope('deconv6') as scope:
+            weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer)
+            weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer)
+            weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer)
+            weights['conv61_biases'] = tf.get_variable('biases', [128])
+            weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
+            weights['conv62_biases'] = tf.get_variable('biases2', [128])
+
+        with tf.variable_scope('deconv7') as scope:
+            weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer)
+            weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer)
+            weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer)
+            weights['conv71_biases'] = tf.get_variable('biases', [64])
+            weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
+            weights['conv72_biases'] = tf.get_variable('biases2', [64])
+
+        with tf.variable_scope('deconv8') as scope:
+            weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer)
+            weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer)
+            weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer)
+            weights['conv81_biases'] = tf.get_variable('biases', [32])
+            weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer)
+            weights['conv82_biases'] = tf.get_variable('biases2', [32])
+
+        with tf.variable_scope('deconv9') as scope:
+            weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer)
+            weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer)
+            weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer)
+            weights['conv91_biases'] = tf.get_variable('biases', [16])
+            weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer)
+            weights['conv92_biases'] = tf.get_variable('biases2', [16])
+
+        with tf.variable_scope('output') as scope:
+            weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer)
+            weights['output_biases'] = tf.get_variable('biases', [2])
+
+        return weights
+
+    def forward_unet(self, inp, weights, is_training=True):
+
+        self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training)
+        self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training)
+        self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID')
+        # 192x192x16
+        self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training)
+        self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training)
+        self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID')
+        # 96x96x32
+        self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training)
+        self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training)
+        self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID')
+        # 48x48x64
+        self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training)
+        self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training)
+        self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID')
+        # 24x24x128
+        self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training)
+        self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training)
+        # 24x24x256
+
+        ## decoder: each deconv upsamples and halves the channel number; the result is then
+        ## concatenated with itself (rather than with an encoder feature map) to restore the channel count
+        self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training)
+        # 48x48x128
+        self.sum6 = concat2d(self.deconv6, self.deconv6)
+        self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training)
+        self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training)
+        # 48x48x128
+
+        self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training)
+        # 96x96x64
+        self.sum7 = concat2d(self.deconv7, self.deconv7)
+        self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training)
+        self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training)
+        # 96x96x64
+
+        self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training)
+        # 192x192x32
+        self.sum8 = concat2d(self.deconv8, self.deconv8)
+        self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training)
+        self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training)
+        self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
+        # 192x192x32
+
+        self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training)
+        # 384x384x16
+        self.sum9 = concat2d(self.deconv9, self.deconv9)
+        self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training)
+        self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training)
+        # 384x384x16
+
+        self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='output/bn', bn=False, is_training=is_training)
+        #384x384x2
+
+        self.pred_prob = tf.nn.softmax(self.logits) # shape [batch, w, h, num_classes]
+        self.pred_compact = tf.argmax(self.pred_prob, axis=-1) # shape [batch, w, h]
+
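+        # 48-channel embedding map: upsampled conv8 features (32 channels) concatenated with conv9 features (16 channels),
+        # later consumed by extract_coutour_embedding and forward_metric_net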
+        self.embeddings = concat2d(self.conv82_resize, self.conv92)
+
+        return self.pred_prob, self.pred_compact, self.embeddings