# Segmentation/model/Hundred_Layer_Tiramisu.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
4
'''The implementation of the 100-layer Tiramisu network (FC-DenseNet103)
is based on the publication found at https://arxiv.org/pdf/1611.09326.pdf'''
6
7
class Hundred_Layer_Tiramisu(tf.keras.Model):
    """Fully convolutional DenseNet ("Tiramisu") for semantic segmentation.

    Layers built in ``__init__``:

    * an initial ``Conv2D`` with ``num_channels`` filters and 'same' padding;
    * a downsampling path with one ``dense_layer`` per entry of
      ``layers_per_block``; every dense block except the last one (the
      bottleneck) is followed by a ``down_transition``;
    * an upsampling path with one ``up_transition`` per skip connection,
      consuming the skip connections deepest-first;
    * a final 1x1 ``Conv2D`` with ``num_classes`` filters, followed by a
      sigmoid when ``num_classes == 1`` and a softmax otherwise.

    Args:
        growth_rate: Feature maps added by each conv unit of a dense block.
        layers_per_block: Sequence of conv-unit counts for the downsampling
            dense blocks; the last entry is the bottleneck block.
        num_channels: Filters of the initial convolution.
        num_classes: Filters of the final 1x1 convolution.
        kernel_size: Kernel size of the initial convolution, of the
            dense-block convolutions and of the transposed convolutions.
        pool_size: Max-pooling window of the down transitions.
        nonlinearity: Activation used by the conv units and transitions.
        dropout_rate: Dropout rate; only used when ``use_dropout`` is True.
        strides: Strides of the transposed convolutions. They must undo
            ``pool_size`` so that upsampled maps match the skip connections.
        padding: Padding of the transposed convolutions.
        use_dropout: Enables dropout in the downsampling dense blocks and in
            the down transitions.
        use_concat: Stored on the model but not used. Downsampling dense
            blocks always append their input; upsampling ones never do.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 growth_rate,
                 layers_per_block,
                 num_channels,
                 num_classes,
                 kernel_size=(3, 3),
                 pool_size=(2, 2),
                 nonlinearity='relu',
                 dropout_rate=0.2,
                 strides=(2, 2),
                 padding='same',
                 use_dropout=False,
                 use_concat=True,
                 **kwargs):

        super(Hundred_Layer_Tiramisu, self).__init__(**kwargs)

        self.growth_rate = growth_rate
        self.layers_per_block = layers_per_block
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.nonlinearity = nonlinearity
        self.dropout_rate = dropout_rate
        self.strides = strides
        self.padding = padding
        self.use_dropout = use_dropout
        self.use_concat = use_concat

        self.conv_3x3 = tfkl.Conv2D(self.num_channels,
                                    kernel_size,
                                    padding='same')
        # Downsampling path, in call order: dense, down, dense, down, ..., dense.
        self.dense_block_list = []
        self.up_transition_list = []

        self.conv_1x1 = tfkl.Conv2D(filters=num_classes,
                                    kernel_size=(1, 1),
                                    padding='same')

        num_blocks = len(layers_per_block)
        layers_counter = 0

        for idx, num_conv_layers in enumerate(layers_per_block):
            self.dense_block_list.append(dense_layer(num_conv_layers,
                                                     growth_rate,
                                                     kernel_size=kernel_size,
                                                     dropout_rate=dropout_rate,
                                                     nonlinearity=nonlinearity,
                                                     use_dropout=use_dropout,
                                                     use_concat=True))

            # Down-path blocks append their input (use_concat=True), so after
            # each block the channel count is the initial channels plus
            # growth_rate for every conv unit seen so far.
            layers_counter += num_conv_layers
            num_filters = num_channels + layers_counter * growth_rate

            # No transition after the bottleneck (last) block.
            if idx != num_blocks - 1:
                self.dense_block_list.append(down_transition(num_channels=num_filters,
                                                             kernel_size=(1, 1),
                                                             pool_size=pool_size,
                                                             dropout_rate=dropout_rate,
                                                             nonlinearity=nonlinearity,
                                                             use_dropout=use_dropout))

        # Upsampling path: one transition up per skip connection, deepest first.
        # Each one mirrors the downsampling block at the same resolution.
        for idx in range(num_blocks - 1, 0, -1):
            num_conv_layers = layers_per_block[idx - 1]
            self.up_transition_list.append(up_transition(num_conv_layers,
                                                         num_channels=num_conv_layers * growth_rate,
                                                         growth_rate=self.growth_rate,
                                                         kernel_size=kernel_size,
                                                         strides=strides,
                                                         padding=padding,
                                                         nonlinearity=nonlinearity,
                                                         use_concat=False))

        # Weightless output activation, created once instead of on every call.
        self.output_activation = tfkl.Activation(
            'sigmoid' if num_classes == 1 else 'softmax')

    def call(self, inputs, training=False):
        """Run the network on ``inputs`` and return per-pixel class scores.

        Args:
            inputs: Image batch; assumed channels-last 4-D -- TODO confirm.
            training: Forwarded to sub-blocks (batch norm / dropout).

        Returns:
            Sigmoid (single class) or softmax (multi-class) activations of the
            final 1x1 convolution.
        """
        skips = []
        x = self.conv_3x3(inputs)
        last = len(self.dense_block_list) - 1
        for i, block in enumerate(self.dense_block_list):
            x = block(x, training=training)
            # Even indices are dense blocks. Every one except the bottleneck
            # feeds a skip connection.
            if i % 2 == 0 and i != last:
                skips.append(x)

        # Skip connections are consumed deepest-first.
        for up, skip in zip(self.up_transition_list, reversed(skips)):
            x = up(x, skip, training=training)

        x = self.conv_1x1(x)
        return self.output_activation(x)
103
104
'''------------------------------------------------------------------'''
105
106
class conv_layer(tf.keras.Sequential):
    """Convolution unit used inside the dense blocks.

    Stages, applied in order:
      1. BatchNormalization over the last axis (momentum 0.95, epsilon 1e-3);
      2. the ``nonlinearity`` activation;
      3. Conv2D with ``num_channels`` filters, ``kernel_size`` and 'same'
         padding (with bias, no built-in activation);
      4. Dropout(``dropout_rate``), only when ``use_dropout`` is True.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=(3, 3),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 **kwargs):

        super(conv_layer, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout

        stages = [
            tfkl.BatchNormalization(axis=-1, momentum=0.95, epsilon=0.001),
            tfkl.Activation(self.nonlinearity),
            tfkl.Conv2D(self.num_channels,
                        self.kernel_size,
                        padding='same',
                        activation=None,
                        use_bias=True),
        ]
        if use_dropout:
            stages.append(tfkl.Dropout(rate=self.dropout_rate))

        for stage in stages:
            self.add(stage)

    def call(self, inputs, training=False):
        """Apply the stages in order; ``training`` drives BN/dropout mode."""
        return super(conv_layer, self).call(inputs, training=training)
143
144
'''-----------------------------------------------------------------'''
145
146
class dense_layer(tf.keras.Sequential):
    """Dense block: ``num_conv_layers`` densely connected ``conv_layer`` units.

    Unit ``k`` receives the block input concatenated (last axis) with the
    outputs of units ``0..k-1`` and produces ``growth_rate`` feature maps.
    The block returns the concatenation of all unit outputs along the last
    axis, followed by the block input when ``use_concat`` is True.

    NOTE(review): this subclasses ``tf.keras.Sequential`` but never calls
    ``add()``. The units are tracked through ``conv_list`` and wired in
    ``call``. Confirm this is intended rather than subclassing
    ``tf.keras.Model``.

    Args:
        num_conv_layers: Number of conv units in the block.
        growth_rate: Filters produced by each conv unit.
        kernel_size: Kernel size of the conv units.
        dropout_rate: Dropout rate of the conv units (used if ``use_dropout``).
        nonlinearity: Activation used by the conv units.
        use_dropout: Whether the conv units apply dropout.
        use_concat: Whether the block input is appended to the output.
        **kwargs: Forwarded to ``tf.keras.Sequential``.
    """

    def __init__(self,
                 num_conv_layers,
                 growth_rate,
                 kernel_size=(3, 3),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 use_concat=True,
                 **kwargs):

        super(dense_layer, self).__init__(**kwargs)

        self.num_conv_layers = num_conv_layers
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout
        self.use_concat = use_concat

        self.conv_list = []
        for _ in range(num_conv_layers):
            self.conv_list.append(conv_layer(num_channels=self.growth_rate,
                                             kernel_size=self.kernel_size,
                                             dropout_rate=self.dropout_rate,
                                             nonlinearity=self.nonlinearity,
                                             use_dropout=self.use_dropout))

    def call(self, inputs, training=False):
        """Run the dense block on ``inputs``.

        Returns:
            The unit outputs concatenated on the last axis, followed by
            ``inputs`` when ``use_concat`` is True.

        Raises:
            ValueError: if the block has no conv units and ``use_concat`` is
                False, so there is nothing to return.
        """
        features = inputs
        new_maps = []
        last = len(self.conv_list) - 1
        for i, conv in enumerate(self.conv_list):
            out = conv(features, training=training)
            new_maps.append(out)
            # Only units that have a successor need to grow its input; a
            # concatenation after the last unit would be discarded.
            if i < last:
                features = tfkl.concatenate([features, out], axis=-1)

        if self.use_concat:
            pieces = new_maps + [inputs]
        else:
            pieces = new_maps
        return self._concat_last_axis(pieces)

    @staticmethod
    def _concat_last_axis(tensors):
        """Concatenate ``tensors`` on the last axis; pass a lone tensor through."""
        if not tensors:
            raise ValueError('dense_layer has nothing to return: it needs at '
                             'least one conv layer or use_concat=True.')
        if len(tensors) == 1:
            # Some Keras versions require at least two inputs for a
            # Concatenate layer; concatenating one tensor is the identity.
            return tensors[0]
        return tfkl.concatenate(tensors, axis=-1)
191
192
'''-----------------------------------------------------------------'''
193
194
class down_transition(tf.keras.Sequential):
    """Transition-down block of the downsampling path.

    Stages, applied in order:
      1. BatchNormalization over the last axis (momentum 0.95, epsilon 1e-3);
      2. the ``nonlinearity`` activation;
      3. Conv2D with ``num_channels`` filters, ``kernel_size`` and 'same'
         padding;
      4. Dropout(``dropout_rate``), only when ``use_dropout`` is True;
      5. MaxPooling2D with window ``pool_size``.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=(1, 1),
                 pool_size=(2, 2),
                 dropout_rate=0.2,
                 nonlinearity='relu',
                 use_dropout=False,
                 **kwargs):

        super(down_transition, self).__init__(**kwargs)

        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.nonlinearity = nonlinearity
        self.use_dropout = use_dropout

        pipeline = [
            tfkl.BatchNormalization(axis=-1, momentum=0.95, epsilon=0.001),
            tfkl.Activation(nonlinearity),
            tfkl.Conv2D(num_channels, kernel_size, padding='same'),
        ]
        if use_dropout:
            pipeline.append(tfkl.Dropout(rate=self.dropout_rate))
        pipeline.append(tfkl.MaxPooling2D(pool_size))

        for stage in pipeline:
            self.add(stage)

    def call(self, inputs, training=False):
        """Apply the stages in order; ``training`` drives BN/dropout mode."""
        return super(down_transition, self).call(inputs, training=training)
229
230
'''-----------------------------------------------------------------'''
231
232
class up_transition(tf.keras.Model):
    """Upsampling step of the decoder path.

    Upsamples ``inputs`` with a transposed convolution, runs a dense block on
    the result, and concatenates the dense-block output with the skip
    connection ``bridge`` along the last axis.

    NOTE(review): the paper (arXiv:1611.09326, Fig. 1) concatenates the skip
    connection *before* the upsampling-path dense block; this block does it
    *after*. Kept as-is to preserve the existing architecture and weights.

    Args:
        num_conv_layers: Number of conv units in the dense block.
        num_channels: Filters of the transposed convolution.
        growth_rate: Filters produced by each dense-block conv unit.
        kernel_size: Kernel size of the transposed conv and of the dense
            block's conv units.
        strides: Strides (upsampling factor) of the transposed convolution.
        padding: Padding of the transposed convolution.
        nonlinearity: Activation used inside the dense block.
        use_concat: Whether the dense block appends its own input to its
            output.
        dropout_rate: Dropout rate of the dense block's conv units (used only
            when ``use_dropout`` is True).
        use_dropout: Whether the dense block's conv units apply dropout.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 num_conv_layers,
                 num_channels,
                 growth_rate,
                 kernel_size=(3, 3),
                 strides=(2, 2),
                 padding='same',
                 nonlinearity='relu',
                 use_concat=False,
                 dropout_rate=0.2,
                 use_dropout=False,
                 **kwargs):

        super(up_transition, self).__init__(**kwargs)

        self.num_conv_layers = num_conv_layers
        self.num_channels = num_channels
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.nonlinearity = nonlinearity
        self.use_concat = use_concat
        self.dropout_rate = dropout_rate
        self.use_dropout = use_dropout

        self.up_conv = tfkl.Conv2DTranspose(num_channels,
                                            kernel_size,
                                            strides=strides,
                                            padding=padding)

        # Keyword arguments on purpose: dense_layer's fourth positional
        # parameter is dropout_rate, which previously received `strides`.
        self.dense_block = dense_layer(num_conv_layers,
                                       growth_rate,
                                       kernel_size=kernel_size,
                                       dropout_rate=dropout_rate,
                                       nonlinearity=nonlinearity,
                                       use_dropout=use_dropout,
                                       use_concat=self.use_concat)

    def call(self, inputs, bridge, training=False):
        """Upsample ``inputs``, apply the dense block, append ``bridge``.

        Args:
            inputs: Feature maps from the deeper level.
            bridge: Skip-connection feature maps; must match the upsampled
                spatial size.
            training: Forwarded to the dense block (batch norm / dropout).

        Returns:
            Dense-block output concatenated with ``bridge`` on the last axis.
        """
        upsampled = self.up_conv(inputs)
        features = self.dense_block(upsampled, training=training)
        return tfkl.concatenate([features, bridge], axis=-1)