|
Segmentation/model/Hundred_Layer_Tiramisu.py
|
|
import tensorflow as tf
import tensorflow.keras.layers as tfkl

'''The implementation of the 100-layer Tiramisu network follows
directly from the publication found at https://arxiv.org/pdf/1611.09326.pdf'''
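
# Architecture overview (a summary of the paper, for orientation, matching the
# construction below): an initial convolution, a downsampling path of dense
# blocks with transitions down, a bottleneck dense block, an upsampling path of
# transposed convolutions with dense blocks and skip connections, and a final
# 1x1 convolution with a sigmoid or softmax head.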
|
|
class Hundred_Layer_Tiramisu(tf.keras.Model):
|
|
    def __init__(self,
                 growth_rate,
                 layers_per_block,
                 num_channels,
                 num_classes,
                 kernel_size=(3, 3),
                 pool_size=(2, 2),
                 nonlinearity='relu',
                 dropout_rate=0.2,
                 strides=(2, 2),
                 padding='same',
                 use_dropout=False,
                 use_concat=True,
                 **kwargs):

        super(Hundred_Layer_Tiramisu, self).__init__(**kwargs)
|
|
        self.growth_rate = growth_rate
        self.layers_per_block = layers_per_block
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.nonlinearity = nonlinearity
        self.dropout_rate = dropout_rate
        self.strides = strides
        self.padding = padding
        self.use_dropout = use_dropout
        self.use_concat = use_concat
|
|
        # Initial 3x3 convolution applied to the network input.
        self.conv_3x3 = tfkl.Conv2D(self.num_channels,
                                    kernel_size,
                                    padding='same')
        self.dense_block_list = []
        self.up_transition_list = []

        # Final 1x1 convolution mapping features to per-pixel class scores.
        self.conv_1x1 = tfkl.Conv2D(filters=num_classes,
                                    kernel_size=(1, 1),
                                    padding='same')

        layers_counter = 0
        num_filters = num_channels
|
|
        # Downsampling path: a dense block followed by a transition down after
        # every block except the last one (the bottleneck).
        for idx in range(len(self.layers_per_block)):
            num_conv_layers = layers_per_block[idx]
            self.dense_block_list.append(dense_layer(num_conv_layers,
                                                     growth_rate,
                                                     kernel_size,
                                                     dropout_rate,
                                                     nonlinearity,
                                                     use_dropout=use_dropout,
                                                     use_concat=True))

            # Each dense block adds num_conv_layers * growth_rate feature maps
            # on top of the channels entering it.
            layers_counter = layers_counter + num_conv_layers
            num_filters = num_channels + layers_counter * growth_rate

            if idx != len(self.layers_per_block) - 1:
                # The transition down uses a 1x1 convolution, as in the paper.
                self.dense_block_list.append(down_transition(num_channels=num_filters,
                                                             kernel_size=(1, 1),
                                                             pool_size=pool_size,
                                                             dropout_rate=dropout_rate,
                                                             nonlinearity=nonlinearity,
                                                             use_dropout=use_dropout))
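
        # Illustrative numbers (an assumption, not from the original file):
        # with num_channels=48, growth_rate=16 and layers_per_block starting
        # [4, 5, ...] as in the paper's FC-DenseNet103, the first dense block
        # adds 4 * 16 = 64 feature maps, so the first transition down
        # convolves 48 + 64 = 112 channels.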
|
|
        # Upsampling path: mirror the downsampling blocks in reverse order.
        for idx in range(len(self.layers_per_block) - 1, 0, -1):
            num_conv_layers = layers_per_block[idx - 1]
            num_filters = num_conv_layers * growth_rate
            self.up_transition_list.append(up_transition(num_conv_layers,
                                                         num_channels=num_filters,
                                                         growth_rate=self.growth_rate,
                                                         kernel_size=kernel_size,
                                                         strides=strides,
                                                         padding=padding,
                                                         nonlinearity=nonlinearity,
                                                         use_concat=False))
|
|
    def call(self, inputs, training=False):
        blocks = []
        x = self.conv_3x3(inputs)

        # Downsampling path: dense blocks sit at even indices of
        # dense_block_list, transitions down at odd indices. Store every
        # dense-block output except the bottleneck's as a skip connection.
        for i, down in enumerate(self.dense_block_list):
            x = down(x, training=training)
            if i % 2 == 0 and i != len(self.dense_block_list) - 1:
                blocks.append(x)

        # Upsampling path: consume the stored skip connections in reverse.
        for i, up in enumerate(self.up_transition_list):
            x = up(x, blocks[-i - 1], training=training)

        x = self.conv_1x1(x)
        if self.num_classes == 1:
            output = tfkl.Activation('sigmoid')(x)
        else:
            output = tfkl.Activation('softmax')(x)
        return output
|
|
'''------------------------------------------------------------------'''


# A single BN -> activation -> convolution (-> optional dropout) unit,
# the basic building block of a dense block.
class conv_layer(tf.keras.Sequential):
|
|
    def __init__(self,
                 num_channels,
                 kernel_size=(3, 3),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 **kwargs):

        super(conv_layer, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout

        self.add(tfkl.BatchNormalization(axis=-1,
                                         momentum=0.95,
                                         epsilon=0.001))
        self.add(tfkl.Activation(self.nonlinearity))
        self.add(tfkl.Conv2D(self.num_channels,
                             self.kernel_size,
                             padding='same',
                             activation=None,
                             use_bias=True))

        if use_dropout:
            self.add(tfkl.Dropout(rate=self.dropout_rate))

    def call(self, inputs, training=False):
        outputs = super(conv_layer, self).call(inputs, training=training)
        return outputs
|
|
'''-----------------------------------------------------------------'''


# A dense block: a stack of conv_layers where each layer sees the block
# input concatenated with all feature maps produced before it.
class dense_layer(tf.keras.Sequential):
|
|
    def __init__(self,
                 num_conv_layers,
                 growth_rate,
                 kernel_size=(3, 3),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 use_concat=True,
                 **kwargs):

        super(dense_layer, self).__init__(**kwargs)

        self.num_conv_layers = num_conv_layers
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout
        self.use_concat = use_concat

        # Each conv_layer produces growth_rate new feature maps.
        self.conv_list = []
        for layer in range(num_conv_layers):
            self.conv_list.append(conv_layer(num_channels=self.growth_rate,
                                             kernel_size=self.kernel_size,
                                             dropout_rate=self.dropout_rate,
                                             nonlinearity=self.nonlinearity,
                                             use_dropout=self.use_dropout))
|
|
    def call(self, inputs, training=False):
        dense_output = []
        x = inputs
        # Each layer consumes the concatenation of the block input and all
        # previously produced feature maps.
        for i, conv in enumerate(self.conv_list):
            out = conv(x, training=training)
            x = tfkl.concatenate([x, out], axis=-1)
            dense_output.append(out)

        # The block output is the concatenation of the newly produced maps
        # only; the input is re-attached when use_concat is set (the
        # downsampling path of the Tiramisu).
        x = tfkl.concatenate(dense_output, axis=-1)

        if self.use_concat:
            x = tfkl.concatenate([x, inputs], axis=-1)

        outputs = x
        return outputs
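
    # Shape sketch (an assumption, not from the original file): with
    # num_conv_layers=4, growth_rate=16 and a 112-channel input, the block
    # emits 4 * 16 = 64 channels, or 64 + 112 = 176 with use_concat=True.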
|
|
'''-----------------------------------------------------------------'''


# Transition down: BN -> activation -> 1x1 convolution (-> optional dropout)
# followed by max pooling, which halves the spatial resolution.
class down_transition(tf.keras.Sequential):
|
|
    def __init__(self,
                 num_channels,
                 kernel_size=(1, 1),
                 pool_size=(2, 2),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 **kwargs):

        super(down_transition, self).__init__(**kwargs)

        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout

        self.add(tfkl.BatchNormalization(axis=-1,
                                         momentum=0.95,
                                         epsilon=0.001))
        self.add(tfkl.Activation(nonlinearity))
        self.add(tfkl.Conv2D(num_channels, kernel_size, padding='same'))

        if use_dropout:
            self.add(tfkl.Dropout(rate=self.dropout_rate))

        self.add(tfkl.MaxPooling2D(pool_size))

    def call(self, inputs, training=False):
        outputs = super(down_transition, self).call(inputs, training=training)
        return outputs
|
|
'''-----------------------------------------------------------------'''


# Transition up: a transposed convolution doubles the spatial resolution,
# then a dense block refines the features before concatenation with the
# skip connection from the downsampling path.
class up_transition(tf.keras.Model):
|
|
    def __init__(self,
                 num_conv_layers,
                 num_channels,
                 growth_rate,
                 kernel_size=(3, 3),
                 strides=(2, 2),
                 padding='same',
                 nonlinearity='relu',
                 use_concat=False,
                 **kwargs):

        super(up_transition, self).__init__(**kwargs)

        self.num_conv_layers = num_conv_layers
        self.num_channels = num_channels
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.nonlinearity = nonlinearity
        self.use_concat = use_concat

        self.up_conv = tfkl.Conv2DTranspose(num_channels,
                                            kernel_size,
                                            strides,
                                            padding)

        # Keyword arguments keep strides out of dense_layer's dropout_rate slot.
        self.dense_block = dense_layer(num_conv_layers,
                                       growth_rate,
                                       kernel_size,
                                       nonlinearity=nonlinearity,
                                       use_concat=self.use_concat)
|
|
    def call(self, inputs, bridge, training=False):
        # Upsample, run the dense block, then concatenate with the skip
        # connection ("bridge") from the downsampling path.
        up = self.up_conv(inputs, training=training)
        db_up = self.dense_block(up, training=training)
        c_up = tfkl.concatenate([db_up, bridge], axis=3)
        return c_up
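

'''-----------------------------------------------------------------'''

# Minimal usage sketch (an assumption, not part of the original file). The
# layer counts follow the paper's FC-DenseNet103 configuration; the input
# shape and class count are invented for illustration.
if __name__ == '__main__':
    model = Hundred_Layer_Tiramisu(growth_rate=16,
                                   layers_per_block=[4, 5, 7, 10, 12, 15],
                                   num_channels=48,
                                   num_classes=11)
    dummy_batch = tf.random.uniform((1, 224, 224, 3))  # hypothetical input
    prediction = model(dummy_batch, training=False)
    print(prediction.shape)  # expected: (1, 224, 224, 11)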