Switch to unified view

a b/fetal_net/model/resnet/resnet3d.py
1
"""A vanilla 3D resnet implementation.
2
3
Based on Raghavendra Kotikalapudi's 2D implementation
4
keras-resnet (See https://github.com/raghakot/keras-resnet.)
5
"""
6
from __future__ import (
7
    absolute_import,
8
    division,
9
    print_function,
10
    unicode_literals
11
)
12
import six
13
from keras.models import Model
14
from keras.layers import (
15
    Input,
16
    Activation,
17
    Dense,
18
    Flatten
19
)
20
from keras.layers.convolutional import (
21
    Conv3D,
22
    AveragePooling3D,
23
    MaxPooling3D
24
)
25
from keras.layers.merge import add
26
from keras.layers.normalization import BatchNormalization
27
from keras.regularizers import l2
28
from keras import backend as K
29
30
31
def _bn_relu(input):
32
    """Helper to build a BN -> relu block (by @raghakot)."""
33
    norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
34
    return Activation("relu")(norm)
35
36
37
def _conv_bn_relu3D(**conv_params):
38
    filters = conv_params["filters"]
39
    kernel_size = conv_params["kernel_size"]
40
    strides = conv_params.setdefault("strides", (1, 1, 1))
41
    kernel_initializer = conv_params.setdefault(
42
        "kernel_initializer", "he_normal")
43
    padding = conv_params.setdefault("padding", "same")
44
    kernel_regularizer = conv_params.setdefault("kernel_regularizer",
45
                                                l2(1e-4))
46
47
    def f(input):
48
        conv = Conv3D(filters=filters, kernel_size=kernel_size,
49
                      strides=strides, kernel_initializer=kernel_initializer,
50
                      padding=padding,
51
                      kernel_regularizer=kernel_regularizer)(input)
52
        return _bn_relu(conv)
53
54
    return f
55
56
57
def _bn_relu_conv3d(**conv_params):
58
    """Helper to build a  BN -> relu -> conv3d block."""
59
    filters = conv_params["filters"]
60
    kernel_size = conv_params["kernel_size"]
61
    strides = conv_params.setdefault("strides", (1, 1, 1))
62
    kernel_initializer = conv_params.setdefault("kernel_initializer",
63
                                                "he_normal")
64
    padding = conv_params.setdefault("padding", "same")
65
    kernel_regularizer = conv_params.setdefault("kernel_regularizer",
66
                                                l2(1e-4))
67
68
    def f(input):
69
        activation = _bn_relu(input)
70
        return Conv3D(filters=filters, kernel_size=kernel_size,
71
                      strides=strides, kernel_initializer=kernel_initializer,
72
                      padding=padding,
73
                      kernel_regularizer=kernel_regularizer)(activation)
74
    return f
75
76
77
def _shortcut3d(input, residual):
78
    """3D shortcut to match input and residual and merges them with "sum"."""
79
    stride_dim1 = input._keras_shape[DIM1_AXIS] \
80
        // residual._keras_shape[DIM1_AXIS]
81
    stride_dim2 = input._keras_shape[DIM2_AXIS] \
82
        // residual._keras_shape[DIM2_AXIS]
83
    stride_dim3 = input._keras_shape[DIM3_AXIS] \
84
        // residual._keras_shape[DIM3_AXIS]
85
    equal_channels = residual._keras_shape[CHANNEL_AXIS] \
86
        == input._keras_shape[CHANNEL_AXIS]
87
88
    shortcut = input
89
    if stride_dim1 > 1 or stride_dim2 > 1 or stride_dim3 > 1 \
90
            or not equal_channels:
91
        shortcut = Conv3D(
92
            filters=residual._keras_shape[CHANNEL_AXIS],
93
            kernel_size=(1, 1, 1),
94
            strides=(stride_dim1, stride_dim2, stride_dim3),
95
            kernel_initializer="he_normal", padding="valid",
96
            kernel_regularizer=l2(1e-4)
97
            )(input)
98
    return add([shortcut, residual])
99
100
101
def _residual_block3d(block_function, filters, kernel_regularizer, repetitions,
102
                      is_first_layer=False):
103
    def f(input):
104
        for i in range(repetitions):
105
            strides = (1, 1, 1)
106
            if i == 0 and not is_first_layer:
107
                strides = (2, 2, 2)
108
            input = block_function(filters=filters, strides=strides,
109
                                   kernel_regularizer=kernel_regularizer,
110
                                   is_first_block_of_first_layer=(
111
                                       is_first_layer and i == 0)
112
                                   )(input)
113
        return input
114
115
    return f
116
117
118
def basic_block(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
119
                is_first_block_of_first_layer=False):
120
    """Basic 3 X 3 X 3 convolution blocks. Extended from raghakot's 2D impl."""
121
    def f(input):
122
        if is_first_block_of_first_layer:
123
            # don't repeat bn->relu since we just did bn->relu->maxpool
124
            conv1 = Conv3D(filters=filters, kernel_size=(3, 3, 3),
125
                           strides=strides, padding="same",
126
                           kernel_initializer="he_normal",
127
                           kernel_regularizer=kernel_regularizer
128
                           )(input)
129
        else:
130
            conv1 = _bn_relu_conv3d(filters=filters,
131
                                    kernel_size=(3, 3, 3),
132
                                    strides=strides,
133
                                    kernel_regularizer=kernel_regularizer
134
                                    )(input)
135
136
        residual = _bn_relu_conv3d(filters=filters, kernel_size=(3, 3, 3),
137
                                   kernel_regularizer=kernel_regularizer
138
                                   )(conv1)
139
        return _shortcut3d(input, residual)
140
141
    return f
142
143
144
def bottleneck(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
145
               is_first_block_of_first_layer=False):
146
    """Basic 3 X 3 X 3 convolution blocks. Extended from raghakot's 2D impl."""
147
    def f(input):
148
        if is_first_block_of_first_layer:
149
            # don't repeat bn->relu since we just did bn->relu->maxpool
150
            conv_1_1 = Conv3D(filters=filters, kernel_size=(1, 1, 1),
151
                              strides=strides, padding="same",
152
                              kernel_initializer="he_normal",
153
                              kernel_regularizer=kernel_regularizer
154
                              )(input)
155
        else:
156
            conv_1_1 = _bn_relu_conv3d(filters=filters, kernel_size=(1, 1, 1),
157
                                       strides=strides,
158
                                       kernel_regularizer=kernel_regularizer
159
                                       )(input)
160
161
        conv_3_3 = _bn_relu_conv3d(filters=filters, kernel_size=(3, 3, 3),
162
                                   kernel_regularizer=kernel_regularizer
163
                                   )(conv_1_1)
164
        residual = _bn_relu_conv3d(filters=filters * 4, kernel_size=(1, 1, 1),
165
                                   kernel_regularizer=kernel_regularizer
166
                                   )(conv_3_3)
167
168
        return _shortcut3d(input, residual)
169
170
    return f
171
172
173
def _handle_data_format():
174
    global DIM1_AXIS
175
    global DIM2_AXIS
176
    global DIM3_AXIS
177
    global CHANNEL_AXIS
178
    if K.image_data_format() == 'channels_last':
179
        DIM1_AXIS = 1
180
        DIM2_AXIS = 2
181
        DIM3_AXIS = 3
182
        CHANNEL_AXIS = 4
183
    else:
184
        CHANNEL_AXIS = 1
185
        DIM1_AXIS = 2
186
        DIM2_AXIS = 3
187
        DIM3_AXIS = 4
188
189
190
def _get_block(identifier):
191
    if isinstance(identifier, six.string_types):
192
        res = globals().get(identifier)
193
        if not res:
194
            raise ValueError('Invalid {}'.format(identifier))
195
        return res
196
    return identifier
197
198
199
class Resnet3DBuilder(object):
200
    """ResNet3D."""
201
202
    @staticmethod
203
    def build(input_shape, num_outputs, block_fn, repetitions, reg_factor, max_filters):
204
        """Instantiate a vanilla ResNet3D keras model.
205
206
        # Arguments
207
            input_shape: Tuple of input shape in the format
208
            (conv_dim1, conv_dim2, conv_dim3, channels) if dim_ordering='tf'
209
            (filter, conv_dim1, conv_dim2, conv_dim3) if dim_ordering='th'
210
            num_outputs: The number of outputs at the final softmax layer
211
            block_fn: Unit block to use {'basic_block', 'bottlenack_block'}
212
            repetitions: Repetitions of unit blocks
213
        # Returns
214
            model: a 3D ResNet model that takes a 5D tensor (volumetric images
215
            in batch) as input and returns a 1D vector (prediction) as output.
216
        """
217
        _handle_data_format()
218
        if len(input_shape) != 4:
219
            raise ValueError("Input shape should be a tuple "
220
                             "(conv_dim1, conv_dim2, conv_dim3, channels) "
221
                             "for tensorflow as backend or "
222
                             "(channels, conv_dim1, conv_dim2, conv_dim3) "
223
                             "for theano as backend")
224
225
        block_fn = _get_block(block_fn)
226
        input = Input(shape=input_shape)
227
        # first conv
228
        conv1 = _conv_bn_relu3D(filters=64, kernel_size=(7, 7, 7),
229
                                strides=(2, 2, 2),
230
                                kernel_regularizer=l2(reg_factor)
231
                                )(input)
232
        pool1 = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2),
233
                             padding="same")(conv1)
234
235
        # repeat blocks
236
        block = pool1
237
        filters = 64
238
        for i, r in enumerate(repetitions):
239
            block = _residual_block3d(block_fn, filters=filters,
240
                                      kernel_regularizer=l2(reg_factor),
241
                                      repetitions=r, is_first_layer=(i == 0)
242
                                      )(block)
243
            filters *= 2
244
            filters = min(max_filters, filters)
245
246
        # last activation
247
        block_output = _bn_relu(block)
248
249
        # average poll and classification
250
        pool2 = AveragePooling3D(pool_size=(block._keras_shape[DIM1_AXIS],
251
                                            block._keras_shape[DIM2_AXIS],
252
                                            block._keras_shape[DIM3_AXIS]),
253
                                 strides=(1, 1, 1))(block_output)
254
        flatten1 = Flatten()(pool2)
255
        if num_outputs > 1:
256
            dense = Dense(units=num_outputs,
257
                          kernel_initializer="he_normal",
258
                          activation="softmax",
259
                          kernel_regularizer=l2(reg_factor))(flatten1)
260
        else:
261
            dense = Dense(units=num_outputs,
262
                          kernel_initializer="he_normal",
263
                          activation="sigmoid",
264
                          kernel_regularizer=l2(reg_factor))(flatten1)
265
266
        model = Model(inputs=input, outputs=dense)
267
        return model
268
269
    @staticmethod
270
    def build_resnet_18(input_shape, num_outputs, reg_factor=1e-4, max_filters=256):
271
        """Build resnet 18."""
272
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
273
                                     [2, 2, 2, 2], reg_factor=reg_factor, max_filters=256)
274
275
    @staticmethod
276
    def build_resnet_34(input_shape, num_outputs, reg_factor=1e-4):
277
        """Build resnet 34."""
278
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
279
                                     [3, 4, 6, 3], reg_factor=reg_factor)
280
281
    @staticmethod
282
    def build_resnet_50(input_shape, num_outputs, reg_factor=1e-4):
283
        """Build resnet 50."""
284
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
285
                                     [3, 4, 6, 3], reg_factor=reg_factor)
286
287
    @staticmethod
288
    def build_resnet_101(input_shape, num_outputs, reg_factor=1e-4):
289
        """Build resnet 101."""
290
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
291
                                     [3, 4, 23, 3], reg_factor=reg_factor)
292
293
    @staticmethod
294
    def build_resnet_152(input_shape, num_outputs, reg_factor=1e-4):
295
        """Build resnet 152."""
296
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
297
                                     [3, 8, 36, 3], reg_factor=reg_factor)