|
a |
|
b/RefineNet & SESNet/nets/model.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
from tensorflow.contrib import slim |
|
|
3 |
from nets import resnet_v1 |
|
|
4 |
from nets.convlstm_cell import ConvLSTMCell |
|
|
5 |
from utils.training import get_valid_logits_and_labels |
|
|
6 |
# Handle to the process-wide TensorFlow flags registry (not referenced elsewhere in this visible chunk).
FLAGS = tf.app.flags.FLAGS
|
|
7 |
|
|
|
8 |
def unpool(inputs, scale):
    """Bilinearly upsample a feature map by an integer factor.

    Args:
        inputs: image-like tensor whose dims 1 and 2 are read as height and
            width (assumed NHWC -- TODO confirm at call sites).
        scale: multiplier applied to both spatial dimensions.

    Returns:
        ``tf.image.resize_bilinear`` of ``inputs`` at ``scale`` times its
        dynamic height and width.
    """
    dyn_shape = tf.shape(inputs)
    target_hw = [dyn_shape[1] * scale, dyn_shape[2] * scale]
    return tf.image.resize_bilinear(inputs, size=target_hw)
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def ResidualConvUnit(inputs, features=256, kernel_size=3):
    """Residual conv unit: two (ReLU -> conv) steps plus an identity skip.

    Args:
        inputs: 4-D feature map; its channel count should equal ``features``
            so the skip addition lines up.
        features: output channels of both ``slim.conv2d`` layers.
        kernel_size: kernel size of both convolutions.

    Returns:
        ``branch + inputs``, where ``branch`` is the two-conv residual path.
    """
    # NOTE(review): slim.conv2d applies ReLU by default (and model() adds batch
    # norm through its arg_scope), so the residual branch is non-negative; the
    # RefineNet paper uses a linear final conv here -- confirm this is intended.
    branch = inputs
    for _ in range(2):
        branch = slim.conv2d(tf.nn.relu(branch), features, kernel_size)
    return tf.add(branch, inputs)
|
|
19 |
|
|
|
20 |
def ChainedResidualPooling(inputs, features=256, n_stages=2):
    """Chained residual pooling (RefineNet).

    ``inputs`` is first passed through a ReLU. Then ``n_stages`` pooling stages
    are chained: each stage max-pools (5x5, stride 1, SAME) the *previous
    stage's* output and applies a 3x3 ``slim.conv2d``. Every stage's output is
    added to a running sum that starts from the ReLU'd input.

    Args:
        inputs: 4-D feature map whose channel count should equal ``features``
            so the residual sums line up.
        features: output channels of each stage's convolution.
        n_stages: number of chained stages; the default of 2 keeps the
            original number of layers (and therefore variable names).

    Returns:
        ReLU(inputs) plus the sum of all stage outputs.
    """
    net_relu = tf.nn.relu(inputs)
    stage = net_relu
    total = net_relu
    for _ in range(n_stages):
        # Chain: pool the previous stage's output rather than re-pooling the
        # ReLU'd input. Otherwise the stages run in parallel, and the later
        # stages never see a larger receptive field.
        stage = slim.max_pool2d(stage, [5, 5], stride=1, padding='SAME')
        stage = slim.conv2d(stage, features, 3)
        total = tf.add(stage, total)
    return total
|
|
31 |
|
|
|
32 |
|
|
|
33 |
def MultiResolutionFusion(high_inputs=None, low_inputs=None, features=256):
    """Fuse the current-resolution path with an optional coarser path.

    Args:
        high_inputs: ``None`` for the first refine block. Otherwise a pair of
            tensors from the coarser path, which are upsampled to the size of
            ``low_inputs``.
        low_inputs: pair of tensors at this block's output resolution.
        features: channels of every 3x3 adaptation ``slim.conv2d``.

    Returns:
        The sum of the two convolved ``low_inputs``, plus (when
        ``high_inputs`` is given) the two convolved ``high_inputs``
        bilinearly resized to the low path's spatial size.
    """
    # Convs are created in the order low_1, low_2, high_1, high_2, as before,
    # so slim's auto-generated variable names (and checkpoints) are unchanged.
    low_1 = slim.conv2d(low_inputs[0], features, 3)
    low_2 = slim.conv2d(low_inputs[1], features, 3)
    fused_low = tf.add(low_1, low_2)

    if high_inputs is None:  # first refine block: nothing coarser to merge
        return fused_low

    # Resize to the low path's actual size instead of a fixed 2x. When the
    # input size is not a multiple of the backbone stride, feature sizes are
    # odd, and 2x the coarse size would not match in the final add. For an
    # exact 2x ratio this is the same resize_bilinear call as unpool(x, 2).
    target_hw = tf.shape(fused_low)[1:3]
    high_1 = tf.image.resize_bilinear(slim.conv2d(high_inputs[0], features, 3), target_hw)
    high_2 = tf.image.resize_bilinear(slim.conv2d(high_inputs[1], features, 3), target_hw)
    fused_high = tf.add(high_1, high_2)

    return tf.add(fused_low, fused_high)
|
|
62 |
|
|
|
63 |
|
|
|
64 |
def RefineBlock(high_inputs=None, low_inputs=None):
    """Build one refine block.

    The block applies two residual conv units to each input path, fuses the
    paths, applies chained residual pooling, and finishes with one more
    residual conv unit. All widths are 256.

    Args:
        high_inputs: output of the previous (coarser) refine block, or ``None``
            for the first block in the cascade.
        low_inputs: feature map at this block's resolution.

    Returns:
        The output of the final ``ResidualConvUnit``.
    """
    width = 256

    # Two independent RCUs per path; creation order (low first, then high) is
    # significant for slim's auto-generated variable names.
    low_pair = [ResidualConvUnit(low_inputs, features=width) for _ in range(2)]
    if high_inputs is None:
        high_pair = None
    else:
        high_pair = [ResidualConvUnit(high_inputs, features=width) for _ in range(2)]

    fused = MultiResolutionFusion(high_inputs=high_pair, low_inputs=low_pair, features=width)
    pooled = ChainedResidualPooling(fused, features=width)
    return ResidualConvUnit(pooled, features=width)
|
|
88 |
|
|
|
89 |
|
|
|
90 |
|
|
|
91 |
def _sesnet_temporal_head(features):
    """Run two stacked ConvLSTMs over the batch axis, treated as time.

    The batch of feature maps becomes one sequence (batch 1, time = original
    batch size). The two ``ConvLSTMCell``s are built with ``C // 2`` and
    ``C // 4`` as their second argument (presumably the filter count --
    confirm in nets.convlstm_cell). The result is squeezed back to the
    original batch-first layout.
    """
    in_shape = features.shape
    sequence = tf.expand_dims(features, axis=0)

    lstm_cell_1 = ConvLSTMCell([in_shape[1], in_shape[2]], in_shape[3] // 2, [3, 3])
    lstm_cell_2 = ConvLSTMCell([in_shape[1], in_shape[2]], in_shape[3] // 4, [3, 3])

    with tf.variable_scope('rnn_scope_0', reuse=tf.AUTO_REUSE):
        output0, _ = tf.nn.dynamic_rnn(lstm_cell_1, sequence, dtype=sequence.dtype)
    with tf.variable_scope('rnn_scope_1', reuse=tf.AUTO_REUSE):
        output1, _ = tf.nn.dynamic_rnn(lstm_cell_2, output0, dtype=output0.dtype)

    return tf.squeeze(output1, axis=0)


def model(model_type, images, weight_decay=1e-5, is_training=True, num_classes=2):
    """Build the RefineNet / SESNet segmentation graph on a ResNet-101 backbone.

    Args:
        model_type: ``'sesnet'`` adds two stacked ConvLSTM layers that treat the
            batch axis as time; any other value builds plain RefineNet.
        images: image batch whose channel count must match
            ``mean_image_subtraction``'s means (3 by default).
        weight_decay: L2 regularization strength for backbone and decoder convs.
        is_training: forwarded to the backbone and to batch norm.
        num_classes: channels of the output score map (default 2, as before).

    Returns:
        ``F_score``: unnormalized per-pixel class scores (logits) at the spatial
        size of ``end_points['pool2']``.
    """
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = resnet_v1.resnet_v1_101(images, is_training=is_training, scope='resnet_v1_101')

    # `values` must be the tensors themselves. `end_points.values` (without the
    # call) is a bound method, which name_scope silently ignores.
    with tf.variable_scope('feature_fusion', values=list(end_points.values())):
        batch_norm_params = {'decay': 0.997, 'epsilon': 1e-5, 'scale': True, 'is_training': is_training}
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            # Backbone features, ordered from the deepest end point to the shallowest.
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i, feat in enumerate(f):
                print('Shape of f_{} {}'.format(i, feat.shape))

            # 1x1 projections to the 256 channels RefineBlock works with.
            # All four are created before any RefineBlock, as before, so the
            # variable names are unchanged.
            h = [slim.conv2d(feat, 256, 1) for feat in f]
            for i, proj in enumerate(h):
                print('Shape of h_{} {}'.format(i, proj.shape))

            # Cascade of refine blocks: each one merges the previous output
            # with the next projection.
            g = RefineBlock(high_inputs=None, low_inputs=h[0])
            for proj in h[1:]:
                g = RefineBlock(g, proj)

            output = g
            if model_type == 'sesnet':
                output = _sesnet_temporal_head(g)

            # Linear output: these scores are logits for a softmax
            # cross-entropy. A ReLU here would clamp negative scores to 0
            # (ties/dead units). The variables created are the same as before.
            F_score = slim.conv2d(output, num_classes, 1, activation_fn=None, normalizer_fn=None)

    return F_score
|
|
143 |
|
|
|
144 |
|
|
|
145 |
def mean_image_subtraction(images, means=(123.68, 116.78, 103.94)):
    """Cast ``images`` to float and subtract a per-channel mean.

    Args:
        images: tensor whose last, statically known dimension is the channel axis.
        means: one value per channel, in channel order. The default appears to
            be a commonly used RGB ImageNet mean (confirm it matches the
            backbone's pretraining). The default is a tuple, not a list, so no
            mutable default is shared between calls; lists are still accepted.

    Returns:
        Float tensor of the same shape, with ``means[c]`` subtracted from
        channel ``c``.

    Raises:
        ValueError: if ``len(means)`` differs from the static channel count.
    """
    images = tf.to_float(images)
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    # Broadcasting over the last axis gives the same result as the former
    # split/subtract/concat, and it is no longer restricted to 4-D input.
    return images - tf.constant(means, dtype=images.dtype)
|
|
154 |
|
|
|
155 |
def loss(annotation_batch, upsampled_logits_batch, class_labels):
    """Mean softmax cross-entropy over the entries kept by ``get_valid_logits_and_labels``.

    Args:
        annotation_batch: ground-truth annotations, forwarded to
            ``get_valid_logits_and_labels``.
        upsampled_logits_batch: per-pixel logits, forwarded to
            ``get_valid_logits_and_labels``.
        class_labels: label values, forwarded to ``get_valid_logits_and_labels``.

    Returns:
        Scalar tensor: the mean (not the sum) of the per-entry cross-entropies.
    """
    valid_labels, valid_logits = get_valid_logits_and_labels(
        annotation_batch_tensor=annotation_batch,
        logits_batch_tensor=upsampled_logits_batch,
        class_labels=class_labels)

    # The labels are assumed to be one-hot along the last axis (confirm in
    # utils.training). axis=-1 is identical to the old dimension=3 for 4-D
    # tensors, also works if the helper returns gathered [N, C] tensors, and
    # avoids the deprecated `dimension` keyword.
    sparse_labels = tf.argmax(valid_labels, axis=-1)
    cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=valid_logits, labels=sparse_labels)

    return tf.reduce_mean(cross_entropies)
|
|
168 |
|