Switch to unified view

a b/fetal_net/model/resnet/resnet.py
1
from __future__ import division
2
3
import six
4
from keras.models import Model
5
from keras.layers import (
6
    Input,
7
    Activation,
8
    Dense,
9
    Flatten
10
)
11
from keras.layers.convolutional import (
12
    Conv2D,
13
    MaxPooling2D,
14
    AveragePooling2D
15
)
16
from keras.layers.merge import add
17
from keras.layers.normalization import BatchNormalization
18
from keras.regularizers import l2
19
from keras import backend as K
20
21
22
def _bn_relu(input):
23
    """Helper to build a BN -> relu block
24
    """
25
    norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
26
    return Activation("relu")(norm)
27
28
29
def _conv_bn_relu(**conv_params):
30
    """Helper to build a conv -> BN -> relu block
31
    """
32
    filters = conv_params["filters"]
33
    kernel_size = conv_params["kernel_size"]
34
    strides = conv_params.setdefault("strides", (1, 1))
35
    kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
36
    padding = conv_params.setdefault("padding", "same")
37
    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))
38
39
    def f(input):
40
        conv = Conv2D(filters=filters, kernel_size=kernel_size,
41
                      strides=strides, padding=padding,
42
                      kernel_initializer=kernel_initializer,
43
                      kernel_regularizer=kernel_regularizer)(input)
44
        return _bn_relu(conv)
45
46
    return f
47
48
49
def _bn_relu_conv(**conv_params):
50
    """Helper to build a BN -> relu -> conv block.
51
    This is an improved scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf
52
    """
53
    filters = conv_params["filters"]
54
    kernel_size = conv_params["kernel_size"]
55
    strides = conv_params.setdefault("strides", (1, 1))
56
    kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
57
    padding = conv_params.setdefault("padding", "same")
58
    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))
59
60
    def f(input):
61
        activation = _bn_relu(input)
62
        return Conv2D(filters=filters, kernel_size=kernel_size,
63
                      strides=strides, padding=padding,
64
                      kernel_initializer=kernel_initializer,
65
                      kernel_regularizer=kernel_regularizer)(activation)
66
67
    return f
68
69
70
def _shortcut(input, residual):
71
    """Adds a shortcut between input and residual block and merges them with "sum"
72
    """
73
    # Expand channels of shortcut to match residual.
74
    # Stride appropriately to match residual (width, height)
75
    # Should be int if network architecture is correctly configured.
76
    input_shape = K.int_shape(input)
77
    residual_shape = K.int_shape(residual)
78
    stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS]))
79
    stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS]))
80
    equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]
81
82
    shortcut = input
83
    # 1 X 1 conv if shape is different. Else identity.
84
    if stride_width > 1 or stride_height > 1 or not equal_channels:
85
        shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],
86
                          kernel_size=(1, 1),
87
                          strides=(stride_width, stride_height),
88
                          padding="valid",
89
                          kernel_initializer="he_normal",
90
                          kernel_regularizer=l2(0.0001))(input)
91
92
    return add([shortcut, residual])
93
94
95
def _residual_block(block_function, filters, repetitions, is_first_layer=False):
96
    """Builds a residual block with repeating bottleneck blocks.
97
    """
98
    def f(input):
99
        for i in range(repetitions):
100
            init_strides = (1, 1)
101
            if i == 0 and not is_first_layer:
102
                init_strides = (2, 2)
103
            input = block_function(filters=filters, init_strides=init_strides,
104
                                   is_first_block_of_first_layer=(is_first_layer and i == 0))(input)
105
        return input
106
107
    return f
108
109
110
def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
111
    """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34.
112
    Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf
113
    """
114
    def f(input):
115
116
        if is_first_block_of_first_layer:
117
            # don't repeat bn->relu since we just did bn->relu->maxpool
118
            conv1 = Conv2D(filters=filters, kernel_size=(3, 3),
119
                           strides=init_strides,
120
                           padding="same",
121
                           kernel_initializer="he_normal",
122
                           kernel_regularizer=l2(1e-4))(input)
123
        else:
124
            conv1 = _bn_relu_conv(filters=filters, kernel_size=(3, 3),
125
                                  strides=init_strides)(input)
126
127
        residual = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv1)
128
        return _shortcut(input, residual)
129
130
    return f
131
132
133
def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
134
    """Bottleneck architecture for > 34 layer resnet.
135
    Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf
136
137
    Returns:
138
        A final conv layer of filters * 4
139
    """
140
    def f(input):
141
142
        if is_first_block_of_first_layer:
143
            # don't repeat bn->relu since we just did bn->relu->maxpool
144
            conv_1_1 = Conv2D(filters=filters, kernel_size=(1, 1),
145
                              strides=init_strides,
146
                              padding="same",
147
                              kernel_initializer="he_normal",
148
                              kernel_regularizer=l2(1e-4))(input)
149
        else:
150
            conv_1_1 = _bn_relu_conv(filters=filters, kernel_size=(1, 1),
151
                                     strides=init_strides)(input)
152
153
        conv_3_3 = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv_1_1)
154
        residual = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1))(conv_3_3)
155
        return _shortcut(input, residual)
156
157
    return f
158
159
160
def _handle_dim_ordering():
161
    global ROW_AXIS
162
    global COL_AXIS
163
    global CHANNEL_AXIS
164
    if K.image_dim_ordering() == 'tf':
165
        ROW_AXIS = 1
166
        COL_AXIS = 2
167
        CHANNEL_AXIS = 3
168
    else:
169
        CHANNEL_AXIS = 1
170
        ROW_AXIS = 2
171
        COL_AXIS = 3
172
173
174
def _get_block(identifier):
175
    if isinstance(identifier, six.string_types):
176
        res = globals().get(identifier)
177
        if not res:
178
            raise ValueError('Invalid {}'.format(identifier))
179
        return res
180
    return identifier
181
182
183
class ResnetBuilder(object):
184
    @staticmethod
185
    def build(input_shape, num_outputs, block_fn, repetitions):
186
        """Builds a custom ResNet like architecture.
187
188
        Args:
189
            input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols)
190
            num_outputs: The number of outputs at final softmax layer
191
            block_fn: The block function to use. This is either `basic_block` or `bottleneck`.
192
                The original paper used basic_block for layers < 50
193
            repetitions: Number of repetitions of various block units.
194
                At each block unit, the number of filters are doubled and the input size is halved
195
196
        Returns:
197
            The keras `Model`.
198
        """
199
        _handle_dim_ordering()
200
        if len(input_shape) != 3:
201
            raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")
202
203
        # Permute dimension order if necessary
204
        if K.image_dim_ordering() == 'tf':
205
            input_shape = (input_shape[1], input_shape[2], input_shape[0])
206
207
        # Load function from str if needed.
208
        block_fn = _get_block(block_fn)
209
210
        input = Input(shape=input_shape)
211
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(input)
212
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)
213
214
        block = pool1
215
        filters = 64
216
        for i, r in enumerate(repetitions):
217
            block = _residual_block(block_fn, filters=filters, repetitions=r, is_first_layer=(i == 0))(block)
218
            filters *= 2
219
220
        # Last activation
221
        block = _bn_relu(block)
222
223
        # Classifier block
224
        block_shape = K.int_shape(block)
225
        pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
226
                                 strides=(1, 1))(block)
227
        flatten1 = Flatten()(pool2)
228
        dense = Dense(units=num_outputs, kernel_initializer="he_normal",
229
                      activation="softmax")(flatten1)
230
231
        model = Model(inputs=input, outputs=dense)
232
        return model
233
234
    @staticmethod
235
    def build_resnet_18(input_shape, num_outputs):
236
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])
237
238
    @staticmethod
239
    def build_resnet_34(input_shape, num_outputs):
240
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])
241
242
    @staticmethod
243
    def build_resnet_50(input_shape, num_outputs):
244
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])
245
246
    @staticmethod
247
    def build_resnet_101(input_shape, num_outputs):
248
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])
249
250
    @staticmethod
251
    def build_resnet_152(input_shape, num_outputs):
252
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])