--- a
+++ b/saml_func.py
@@ -0,0 +1,336 @@
+from __future__ import print_function
+import numpy as np
+import sys
+import tensorflow as tf
+from tensorflow.image import resize_images
+# try:
+#     import special_grads
+# except KeyError as e:
+#     print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr)
+
+from tensorflow.python.platform import flags
+from layer import conv_block, deconv_block, fc, max_pool, concat2d
+from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost
+
+
+class SAML:
+    def __init__(self, args):
+        """Call construct_model_*() after initializing SAML."""
+        self.args = args
+
+        self.batch_size = args.meta_batch_size
+        self.test_batch_size = args.test_batch_size
+        self.volume_size = args.volume_size
+        self.n_class = args.n_class
+        self.compactness_loss_weight = args.compactness_loss_weight
+        self.smoothness_loss_weight = args.smoothness_loss_weight
+        self.margin = args.margin
+
+        self.forward = self.forward_unet
+        self.construct_weights = self.construct_unet_weights
+        self.seg_loss = _get_segmentation_cost
+        self.get_compactness_cost = _get_compactness_cost
+
+    def construct_model_train(self, prefix='metatrain_'):
+        # a: meta-train for inner update, b: meta-test for meta loss
+        self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.inputa1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labela1 = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+        self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1])
+        self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
+        self.training_mode = tf.placeholder_with_default(True, shape=None, name="training_mode_for_bn_moving")
+
+        self.clip_value = self.args.gradients_clip_value
+        self.KEEP_PROB = tf.placeholder(tf.float32)
+
+        with tf.variable_scope('model', reuse=None) as training_scope:
+            if 'weights' in dir(self):
+                print('weights already defined')
+                training_scope.reuse_variables()
+                weights = self.weights
+            else:
+                # Define the weights
+                self.weights = weights = self.construct_weights()
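+
+            # task_metalearn: one meta-learning episode. Compute the segmentation
+            # loss on the two meta-train domains (a, a1), take a single clipped,
+            # first-order inner gradient step to obtain fast weights, then evaluate
+            # the meta-test domain (b) with those fast weights: segmentation loss,
+            # shape-compactness loss, and the contour-embedding triplet
+            # ("smoothness") loss.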
+            def task_metalearn(inp, reuse=True):
+                """Perform the meta-learning update for one episode."""
+                inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp
+
+                # Obtaining the conventional task loss on meta-train
+                task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode)
+                task_lossa = self.seg_loss(task_outputa, labela)
+                task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode)
+                task_lossa1 = self.seg_loss(task_outputa1, labela1)
+
+                ## perform inner update with plain gradient descent on meta-train
+                grads = tf.gradients((task_lossa + task_lossa1) / 2.0, list(weights.values()))
+                grads = [tf.stop_gradient(grad) for grad in grads]  # first-order gradient approximation
+                gradients = dict(zip(weights.keys(), grads))
+                # fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()]))
+                fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()]))
+
+                ## compute compactness loss on meta-test
+                task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode)
+                task_lossb = self.seg_loss(task_outputb, labelb)
+                compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb)
+                compactness_loss_b = self.compactness_loss_weight * compactness_loss_b
+
+                ## compute smoothness loss on meta-test
+                _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode)
+                coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings)
+                metric_embeddings = self.forward_metric_net(coutour_embeddings)
+
+                print(metric_label_group.shape)
+                print(metric_embeddings.shape)
+                smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin)
+                smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b
+
+                task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1]
+                return task_output
+
+            self.global_step = tf.Variable(0, trainable=False)
+            # self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
+            # self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
+            self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False)
+            self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False)
+            self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False)
+
+            input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group)
+            result = task_metalearn(inp=input_tensors)
+            self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1 = result
+
+        ## Performance & Optimization
+        if 'train' in prefix:
+            self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0
+            self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b
+
+            var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')]
+            var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')]
+
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+            with tf.control_dependencies(update_ops):
+                self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step)
+
+            optimizer = tf.train.AdamOptimizer(self.outer_lr)
+            gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor)
+
+            # observe stability of gradients for meta loss
+            # l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
+            # for grad, var in gvs:
+            #     tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad))
+            #     tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var))
+            #     tf.summary.histogram('gradients/' + var.name, var)
+            #     tf.summary.histogram("feature_extractor_var/" + var.name, var)
+
+            # gvs = [(grad, var) for grad, var in gvs]
+            gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs]
+            self.meta_train_op = optimizer.apply_gradients(gvs)
+
+            # for grad, var in gvs:
+            #     tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad))
+            #     tf.summary.histogram('gradients_clipped/' + var.name, var)
+
+            self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric)
+
+            ## Summaries
+            # scalar_summaries = []
+            # train_images = []
+            # val_images = []
+
+            tf.summary.scalar(prefix + 'source_1_loss', self.seg_loss_a)
+            tf.summary.scalar(prefix + 'source_2_loss', self.seg_loss_a1)
+            tf.summary.scalar(prefix + 'target_loss', self.seg_loss_b)
+            tf.summary.scalar(prefix + 'target_compactness_loss', self.compactness_loss_b)
+            tf.summary.scalar(prefix + 'target_length', self.length)
+            tf.summary.scalar(prefix + 'target_area', self.area)
+            tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3))
+            tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:, :, :, 1], tf.float32), 3))
+            tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:, :, :, 1], tf.float32), 3))
+            tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:, :, :], tf.float32), 3))
+            tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:, :, :, 0], tf.float32), 3))
+            tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:, :, :, 1], tf.float32), 3))
+            tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:, :, :, 1], tf.float32), 3))
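+
+    # extract_coutour_embedding: average-pool the embedding map over the sampled
+    # contour/background locations. `coutour` is a [batch, H, W, 1] mask and
+    # `embeddings` is the [batch, H, W, C] feature map from forward_unet, so the
+    # result is one C-dimensional embedding per image, which forward_metric_net
+    # maps into the space used by the triplet ("smoothness") loss.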
+    def extract_coutour_embedding(self, coutour, embeddings):
+
+        coutour_embeddings = coutour * embeddings
+        average_embeddings = tf.reduce_sum(coutour_embeddings, [1, 2]) / tf.reduce_sum(coutour, [1, 2])
+        # print(coutour.shape)
+        # print(embeddings.shape)
+        # print(coutour_embeddings.shape)
+        # print(average_embeddings.shape)
+        return average_embeddings
+
+    def construct_model_test(self, prefix='test'):
+        self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
+        self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
+
+        with tf.variable_scope('model', reuse=None) as testing_scope:
+            if 'weights' in dir(self):
+                testing_scope.reuse_variables()
+                weights = self.weights
+            else:
+                raise ValueError('Weights not initialized. Create the training model before the testing model.')
+
+            outputs, mask, _ = self.forward(self.test_input, weights)
+            losses = self.seg_loss(outputs, self.test_label)
+            # self.pred_prob = tf.nn.softmax(outputs)
+            self.outputs = mask
+
+            self.test_loss = losses
+            # self.test_acc = accuracies
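+
+    # forward_metric_net: a small fully-connected head (48 -> 24 -> 16 with
+    # leaky-ReLU activations) that maps the averaged contour embeddings into the
+    # metric space used by the triplet loss. Its variables live under the
+    # 'metric' scope, so the meta update leaves them untouched and they are
+    # trained by metric_train_op.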
+    def forward_metric_net(self, x):
+
+        with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope:
+
+            w1 = tf.get_variable('w1', shape=[48, 24])
+            b1 = tf.get_variable('b1', shape=[24])
+            out = fc(x, w1, b1, activation='leaky_relu')
+            w2 = tf.get_variable('w2', shape=[24, 16])
+            b2 = tf.get_variable('b2', shape=[16])
+            out = fc(out, w2, b2, activation='leaky_relu')
+
+        return out
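+
+    # construct_unet_weights: build the U-Net filters explicitly as a dictionary
+    # (Xavier-initialized conv kernels for the encoder blocks conv1-conv5,
+    # transposed-conv plus conv kernels for the decoder blocks deconv6-deconv9,
+    # and a final 2-channel output conv). Keeping the weights in a dict lets
+    # forward_unet be re-evaluated with the fast weights from the inner update.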
+    def construct_unet_weights(self):
+
+        weights = {}
+        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
+
+        with tf.variable_scope('conv1') as scope:
+            weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer)
+            weights['conv11_biases'] = tf.get_variable('biases', [16])
+            weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer)
+            weights['conv12_biases'] = tf.get_variable('biases2', [16])
+
+        with tf.variable_scope('conv2') as scope:
+            weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer)
+            weights['conv21_biases'] = tf.get_variable('biases', [32])
+            weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer)
+            weights['conv22_biases'] = tf.get_variable('biases2', [32])
+        ## Network has downsample here
+
+        with tf.variable_scope('conv3') as scope:
+            weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer)
+            weights['conv31_biases'] = tf.get_variable('biases', [64])
+            weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
+            weights['conv32_biases'] = tf.get_variable('biases2', [64])
+
+        with tf.variable_scope('conv4') as scope:
+            weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer)
+            weights['conv41_biases'] = tf.get_variable('biases', [128])
+            weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
+            weights['conv42_biases'] = tf.get_variable('biases2', [128])
+        ## Network has downsample here
+
+        with tf.variable_scope('conv5') as scope:
+            weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer)
+            weights['conv51_biases'] = tf.get_variable('biases', [256])
+            weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer)
+            weights['conv52_biases'] = tf.get_variable('biases2', [256])
+
+        with tf.variable_scope('deconv6') as scope:
+            weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer)
+            weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer)
+            weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer)
+            weights['conv61_biases'] = tf.get_variable('biases', [128])
+            weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
+            weights['conv62_biases'] = tf.get_variable('biases2', [128])
+
+        with tf.variable_scope('deconv7') as scope:
+            weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer)
+            weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer)
+            weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer)
+            weights['conv71_biases'] = tf.get_variable('biases', [64])
+            weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
+            weights['conv72_biases'] = tf.get_variable('biases2', [64])
+
+        with tf.variable_scope('deconv8') as scope:
+            weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer)
+            weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer)
+            weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer)
+            weights['conv81_biases'] = tf.get_variable('biases', [32])
+            weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer)
+            weights['conv82_biases'] = tf.get_variable('biases2', [32])
+
+        with tf.variable_scope('deconv9') as scope:
+            weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer)
+            weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer)
+            weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer)
+            weights['conv91_biases'] = tf.get_variable('biases', [16])
+            weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer)
+            weights['conv92_biases'] = tf.get_variable('biases2', [16])
+
+        with tf.variable_scope('output') as scope:
+            weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer)
+            weights['output_biases'] = tf.get_variable('biases', [2])
+
+        return weights
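+
+    # forward_unet: functional forward pass of the U-Net using the explicit
+    # weight dictionary (so it can be evaluated with either the shared weights
+    # or the fast weights). Returns the softmax probability map, the argmax
+    # segmentation mask, and a 48-channel embedding map (resized conv82 features
+    # concatenated with conv92) used for the contour triplet loss.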
+    def forward_unet(self, inp, weights, is_training=True):
+
+        self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training)
+        self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training)
+        self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID')
+        # 192x192x16
+        self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training)
+        self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training)
+        self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID')
+        # 96x96x32
+        self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training)
+        self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training)
+        self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID')
+        # 48x48x64
+        self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training)
+        self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training)
+        self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID')
+        # 24x24x128
+        self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training)
+        self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training)
+        # 24x24x256
+
+        ## add upsampling, meanwhile, channel number is reduced to half;
+        ## each decoder stage concatenates the upsampled features with the
+        ## corresponding encoder features (U-Net skip connection)
+        self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training)
+        # 48x48x128
+        self.sum6 = concat2d(self.deconv6, self.conv42)
+        self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training)
+        self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training)
+        # 48x48x128
+
+        self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training)
+        # 96x96x64
+        self.sum7 = concat2d(self.deconv7, self.conv32)
+        self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training)
+        self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training)
+        # 96x96x64
+
+        self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training)
+        # 192x192x32
+        self.sum8 = concat2d(self.deconv8, self.conv22)
+        self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training)
+        self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training)
+        # 192x192x32
+        self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
+
+        self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training)
+        # 384x384x16
+        self.sum9 = concat2d(self.deconv9, self.conv12)
+        self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training)
+        self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training)
+        # 384x384x16
+
+        self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='output/bn', bn=False, is_training=is_training)
+        # 384x384x2
+
+        self.pred_prob = tf.nn.softmax(self.logits)  # shape [batch, w, h, num_classes]
+        self.pred_compact = tf.argmax(self.pred_prob, axis=-1)  # shape [batch, w, h]
+
+        # 48-channel embedding map: resized conv82 (32 channels) concatenated with conv92 (16 channels)
+        self.embeddings = concat2d(self.conv82_resize, self.conv92)
+
+        return self.pred_prob, self.pred_compact, self.embeddings
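
For reference, a minimal driver sketch (not part of the patch) showing how the training graph built by construct_model_train() might be run under TensorFlow 1.x, assuming layer.py and utils.py from the same repository are importable. The hyperparameter values, the dummy random batch, and the metric labels below are illustrative placeholders, not the authors' settings; only the class, placeholder, and op names come from saml_func.py above.

# Hypothetical smoke-test driver for SAML (sketch only).
from argparse import Namespace
import numpy as np
import tensorflow as tf
from saml_func import SAML

# Placeholder hyperparameters covering the fields read in __init__ / construct_model_train.
args = Namespace(meta_batch_size=4, test_batch_size=1, volume_size=[384, 384, 3], n_class=2,
                 compactness_loss_weight=1.0, smoothness_loss_weight=1.0, margin=10.0,
                 gradients_clip_value=10.0, inner_lr=1e-3, outer_lr=1e-3, metric_lr=1e-3)

model = SAML(args)
model.construct_model_train(prefix='metatrain_')

b = args.meta_batch_size
h, w, c = args.volume_size
rand = np.random.rand
feed = {model.inputa: rand(b, h, w, c),  model.labela: rand(b, h, w, args.n_class),
        model.inputa1: rand(b, h, w, c), model.labela1: rand(b, h, w, args.n_class),
        model.inputb: rand(b, h, w, c),  model.labelb: rand(b, h, w, args.n_class),
        model.input_group: rand(b, h, w, c),
        model.contour_group: rand(b, h, w, 1),
        # dummy group labels for the triplet loss (two samples per class)
        model.metric_label_group: np.array([[0], [0], [1], [1]], dtype=np.int32),
        model.training_mode: True}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One episode: task update on meta-train, meta update, metric-net update;
    # the original training script may sequence these ops differently.
    sess.run([model.task_train_op, model.meta_train_op, model.metric_train_op], feed_dict=feed)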