--- a +++ b/SegNet/SegNetCMR/inference.py @@ -0,0 +1,79 @@ +import tensorflow as tf +import tensorflow.contrib.slim as slim + +from .layers import unpool_with_argmax + +def inference_scope(is_training, batch_norm_decay=0.9): + with slim.arg_scope([slim.conv2d], + activation_fn=tf.nn.relu, + weights_initializer=tf.truncated_normal_initializer(stddev=0.01), + normalizer_fn=slim.batch_norm, + stride=1, + padding='SAME'): + + with slim.arg_scope([slim.batch_norm], + is_training=is_training, + decay=batch_norm_decay) as sc: + return sc + +def inference(images, class_inc_bg = None): + + tf.summary.image('input', images, max_outputs=3) + + with tf.variable_scope('pool1'): + net = slim.conv2d(images, 64, [3, 3], scope='conv1_1') + net = slim.conv2d(net, 64, [3, 3], scope='conv1_2') + net, arg1 = tf.nn.max_pool_with_argmax(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='maxpool1') + + with tf.variable_scope('pool2'): + net = slim.conv2d(net, 128, [3, 3], scope='conv2_1') + net = slim.conv2d(net, 128, [3, 3], scope='conv2_2') + net, arg2 = tf.nn.max_pool_with_argmax(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='maxpool2') + + with tf.variable_scope('pool3'): + net = slim.conv2d(net, 256, [3, 3], scope='conv3_1') + net = slim.conv2d(net, 256, [3, 3], scope='conv3_2') + net = slim.conv2d(net, 256, [3, 3], scope='conv3_3') + net, arg3 = tf.nn.max_pool_with_argmax(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='maxpool3') + + with tf.variable_scope('pool4'): + net = slim.conv2d(net, 512, [3, 3], scope='conv4_1') + net = slim.conv2d(net, 512, [3, 3], scope='conv4_2') + net = slim.conv2d(net, 512, [3, 3], scope='conv4_3') + net, arg4 = tf.nn.max_pool_with_argmax(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='maxpool4') + + with tf.variable_scope('pool5'): + net = slim.conv2d(net, 512, [3, 3], scope='conv5_1') + net = slim.conv2d(net, 512, [3, 3], scope='conv5_2') + net = slim.conv2d(net, 512, [3, 3], scope='conv5_3') + net, arg5 = tf.nn.max_pool_with_argmax(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='maxpool5') + + with tf.variable_scope('unpool5'): + net = unpool_with_argmax(net, arg5, name='maxunpool5') + net = slim.conv2d(net, 512, [3, 3], scope='uconv5_3') + net = slim.conv2d(net, 512, [3, 3], scope='uconv5_2') + net = slim.conv2d(net, 512, [3, 3], scope='uconv5_1') + + with tf.variable_scope('unpool4'): + net = unpool_with_argmax(net, arg4, name='maxunpool4') + net = slim.conv2d(net, 512, [3, 3], scope='uconv4_3') + net = slim.conv2d(net, 512, [3, 3], scope='uconv4_2') + net = slim.conv2d(net, 256, [3, 3], scope='uconv4_1') + + with tf.variable_scope('unpool3'): + net = unpool_with_argmax(net, arg3, name='maxunpool3') + net = slim.conv2d(net, 256, [3, 3], scope='uconv3_3') + net = slim.conv2d(net, 256, [3, 3], scope='uconv3_2') + net = slim.conv2d(net, 128, [3, 3], scope='uconv3_1') + + with tf.variable_scope('unpool2'): + net = unpool_with_argmax(net, arg2, name='maxunpool2') + net = slim.conv2d(net, 128, [3, 3], scope='uconv2_2') + net = slim.conv2d(net, 64, [3, 3], scope='uconv2_1') + + with tf.variable_scope('unpool1'): + net = unpool_with_argmax(net, arg1, name='maxunpool1') + net = slim.conv2d(net, 64, [3, 3], scope='uconv1_2') + logits = slim.conv2d(net, class_inc_bg, [3, 3], scope='uconv1_1') + + return logits