# File: RefineNet & SESNet/nets/model.py
1
import tensorflow as tf
2
from tensorflow.contrib import slim
3
from nets import resnet_v1
4
from nets.convlstm_cell import ConvLSTMCell
5
from utils.training import get_valid_logits_and_labels
6
# Module-level handle to the tf.app.flags registry (not referenced elsewhere in this file as shown).
FLAGS = tf.app.flags.FLAGS
7
8
def unpool(inputs, scale):
    """Bilinearly upsample ``inputs`` by an integer factor in both spatial dims.

    The target size is taken from the dynamic shape of ``inputs`` (axes 1 and
    2) multiplied by ``scale``.

    Args:
        inputs: 4-D tensor as required by ``tf.image.resize_bilinear``
            ([batch, height, width, channels]).
        scale: multiplicative factor applied to height and width.

    Returns:
        The resized tensor produced by ``tf.image.resize_bilinear``.
    """
    dynamic_shape = tf.shape(inputs)
    new_height = dynamic_shape[1] * scale
    new_width = dynamic_shape[2] * scale
    return tf.image.resize_bilinear(inputs, size=[new_height, new_width])
10
11
12
def ResidualConvUnit(inputs, features=256, kernel_size=3):
    """Residual Conv Unit: (relu -> conv) twice, plus an identity skip.

    When called from ``model()``, each ``slim.conv2d`` here also picks up the
    batch-norm + ReLU settings of that function's ``arg_scope``.

    Args:
        inputs: feature tensor; its channel count must equal ``features`` so the
            final ``tf.add`` with the skip connection is shape-compatible.
        features: number of filters of both convolutions.
        kernel_size: spatial kernel size of both convolutions.

    Returns:
        ``inputs + conv(relu(conv(relu(inputs))))``.
    """
    residual = inputs
    # Same op order as relu, conv, relu, conv -- keeps slim's auto-generated
    # variable names (Conv, Conv_1, ...) unchanged.
    for _ in range(2):
        residual = tf.nn.relu(residual)
        residual = slim.conv2d(residual, features, kernel_size)
    return tf.add(residual, inputs)
19
20
def ChainedResidualPooling(inputs, features=256, n_stages=2, pool_size=5):
    """Chained residual pooling block (RefineNet).

    Applies ReLU to ``inputs``, then ``n_stages`` pool+conv stages. Each stage
    max-pools (stride 1, 'SAME' padding, so spatial size is preserved) and
    convolves the output of the *previous* stage. Every stage output is added
    to a running sum that starts at ``relu(inputs)``.

    Args:
        inputs: feature tensor whose channel count equals ``features``
            (required by the residual additions).
        features: number of filters of each stage's 3x3 convolution.
        n_stages: number of chained pool+conv stages (default 2, as before).
        pool_size: max-pool window size (default 5, as before).

    Returns:
        ``relu(inputs)`` plus the sum of all stage outputs.
    """
    net_relu = tf.nn.relu(inputs)
    path = net_relu
    total = net_relu
    for _ in range(n_stages):
        # Bug fix: the previous version fed ``net_relu`` into the second stage,
        # so the stages were parallel rather than chained. Conv creation order
        # (and thus variable names) is unchanged; outputs of models trained
        # with the old wiring will differ.
        path = slim.max_pool2d(path, [pool_size, pool_size], stride=1, padding='SAME')
        path = slim.conv2d(path, features, 3)
        total = tf.add(path, total)
    return total
31
32
33
def MultiResolutionFusion(high_inputs=None, low_inputs=None, features=256):
    """Fuse feature paths by per-path 3x3 convolution followed by summation.

    The first two tensors of ``low_inputs`` each get their own 3x3 conv and are
    added. If ``high_inputs`` is given, its first two tensors also each get a
    3x3 conv and are upsampled 2x with ``unpool``. Their sum is then added to
    the low-path sum, so ``high_inputs`` are expected to have half the spatial
    size of ``low_inputs``.

    Args:
        high_inputs: None (first/deepest refine block) or a sequence whose first
            two elements are the coarser-resolution paths.
        low_inputs: sequence whose first two elements are the paths at the
            output resolution.
        features: number of filters of every 3x3 convolution.

    Returns:
        The fused tensor, at the resolution of ``low_inputs``.
    """
    # Arguments evaluate left to right, so convs are created in the original
    # order: low[0], low[1], then high[0], high[1].
    low_sum = tf.add(slim.conv2d(low_inputs[0], features, 3),
                     slim.conv2d(low_inputs[1], features, 3))
    if high_inputs is None:
        return low_sum

    high_a = unpool(slim.conv2d(high_inputs[0], features, 3), 2)
    high_b = unpool(slim.conv2d(high_inputs[1], features, 3), 2)
    return tf.add(low_sum, tf.add(high_a, high_b))
62
63
64
def RefineBlock(high_inputs=None, low_inputs=None):
    """One RefineNet refine block with 256-channel paths.

    Each input gets two ``ResidualConvUnit`` branches. The branches are fused by
    ``MultiResolutionFusion`` (``high_inputs`` are upsampled 2x there), passed
    through ``ChainedResidualPooling``, and finished with one more
    ``ResidualConvUnit``.

    NOTE(review): both RCUs of a path are applied in parallel to the same
    input. The RefineNet paper applies two RCUs in series per path, so confirm
    this is intentional.

    Args:
        high_inputs: output of the previous (coarser) refine block, or None for
            the first block.
        low_inputs: projected backbone feature at this block's resolution.

    Returns:
        The refined 256-channel feature map.
    """
    low_paths = [ResidualConvUnit(low_inputs, features=256) for _ in range(2)]

    high_paths = None
    if high_inputs is not None:
        high_paths = [ResidualConvUnit(high_inputs, features=256) for _ in range(2)]

    fused = MultiResolutionFusion(high_inputs=high_paths, low_inputs=low_paths, features=256)
    pooled = ChainedResidualPooling(fused, features=256)
    return ResidualConvUnit(pooled, features=256)
88
89
90
91
def model(model_type, images, weight_decay=1e-5, is_training=True):
    """Build the RefineNet / SESNet score map on a ResNet-101 backbone.

    Args:
        model_type: ``'sesnet'`` appends two ConvLSTM layers that treat the
            batch as a single sequence; any other value builds plain RefineNet.
        images: input image batch; per-channel means are subtracted via
            ``mean_image_subtraction``.
        weight_decay: L2 coefficient for the backbone and fusion convolutions.
        is_training: forwarded to the backbone and to batch normalization.

    Returns:
        ``F_score``: 2-channel output of a 1x1 conv (ReLU, no batch norm) at the
        resolution of the last refine block.
    """
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = resnet_v1.resnet_v1_101(images, is_training=is_training, scope='resnet_v1_101')

    # Bug fix: previously ``[end_points.values]`` passed the bound method
    # rather than the backbone tensors the scope is meant to depend on.
    with tf.variable_scope('feature_fusion', values=list(end_points.values())):
        batch_norm_params = {'decay': 0.997, 'epsilon': 1e-5, 'scale': True, 'is_training': is_training}
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):

            # Backbone features, deepest (pool5) first.
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i, feat in enumerate(f):
                print('Shape of f_{} {}'.format(i, feat.shape))

            # 1x1 projections to 256 channels, created in order h_0..h_3 so
            # slim's auto-generated variable names are unchanged.
            h = [slim.conv2d(feat, 256, 1) for feat in f]
            for i, proj in enumerate(h):
                print('Shape of h_{} {}'.format(i, proj.shape))

            # Top-down refinement: each block fuses the previous block's output
            # with the next-shallower projection.
            g = [RefineBlock(high_inputs=None, low_inputs=h[0])]
            for i in range(1, 4):
                g.append(RefineBlock(g[i - 1], h[i]))

            # NOTE: the final map is intentionally not upsampled here (an
            # ``unpool(g[3], scale=4)`` was previously commented out).
            output = g[3]

            if model_type == 'sesnet':
                # Add a leading axis of size 1 so that, with dynamic_rnn's default
                # batch-major layout, the images of the batch become time steps.
                in_shape = output.shape
                sequence = tf.expand_dims(output, axis=0)

                # in_shape[3] // 2 and // 4: presumably the cells' filter counts
                # (ConvLSTMCell signature not visible here).
                lstm_cell_1 = ConvLSTMCell([in_shape[1], in_shape[2]], in_shape[3] // 2, [3, 3])
                lstm_cell_2 = ConvLSTMCell([in_shape[1], in_shape[2]], in_shape[3] // 4, [3, 3])

                with tf.variable_scope('rnn_scope_0', reuse=tf.AUTO_REUSE):
                    output0, _ = tf.nn.dynamic_rnn(lstm_cell_1, sequence, dtype=sequence.dtype)
                with tf.variable_scope('rnn_scope_1', reuse=tf.AUTO_REUSE):
                    output1, _ = tf.nn.dynamic_rnn(lstm_cell_2, output0, dtype=output0.dtype)

                output = tf.squeeze(output1, axis=0)

            # NOTE(review): ReLU on the class scores clamps them to >= 0 before
            # the softmax loss; confirm intended (changing it alters trained models).
            F_score = slim.conv2d(output, 2, 1, activation_fn=tf.nn.relu, normalizer_fn=None)

    return F_score
143
144
145
def mean_image_subtraction(images, means=(123.68, 116.78, 103.94)):
    """Cast ``images`` to float32 and subtract a per-channel mean.

    Args:
        images: image tensor with channels on the last axis; its static channel
            count must be known.
        means: one value per channel. The defaults appear to be the usual
            ImageNet RGB means (assumption: RGB channel order). A tuple is used
            instead of a list to avoid a shared mutable default; any sequence
            is accepted.

    Returns:
        float32 tensor of the same shape with ``means`` subtracted channel-wise.

    Raises:
        ValueError: if ``len(means)`` differs from the static channel count.
    """
    images = tf.cast(images, tf.float32)  # replaces deprecated tf.to_float
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    # One broadcast subtraction over the trailing channel axis replaces the
    # split -> per-channel subtract -> concat sequence.
    return images - tf.constant(means, dtype=tf.float32)
154
155
def loss(annotation_batch, upsampled_logits_batch, class_labels):
    """Mean sparse softmax cross-entropy over the entries kept by
    ``get_valid_logits_and_labels``.

    Args:
        annotation_batch: ground-truth annotation batch.
        upsampled_logits_batch: network logits at annotation resolution.
        class_labels: label values passed to ``get_valid_logits_and_labels``.

    Returns:
        Scalar tensor: the *mean* (not sum) of the per-entry cross-entropies.
    """
    valid_labels_batch_tensor, valid_logits_batch_tensor = get_valid_logits_and_labels(
        annotation_batch_tensor=annotation_batch,
        logits_batch_tensor=upsampled_logits_batch,
        class_labels=class_labels)

    # Collapse per-class labels on axis 3 to class indices for the sparse loss.
    # Assumes a rank-4 labels tensor with classes last -- TODO confirm against
    # utils.training. ``axis`` replaces the deprecated ``dimension`` keyword.
    sparse_labels = tf.argmax(valid_labels_batch_tensor, axis=3)
    cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=valid_logits_batch_tensor, labels=sparse_labels)

    return tf.reduce_mean(cross_entropies)
168