Diff of /DESS/resnet3d.py [000000] .. [6a4082]

Switch to unified view

a b/DESS/resnet3d.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import six
21
from keras.models import Model
22
from keras.layers import (
23
    Input,
24
    Activation,
25
    Dense,
26
    Flatten,
27
    GlobalMaxPooling3D,
28
    GlobalAveragePooling3D
29
)
30
from keras.layers import Dropout, Dense, Conv3D, MaxPooling3D,GlobalMaxPooling3D, GlobalAveragePooling3D, Activation, BatchNormalization,Flatten
31
32
from keras.layers.convolutional import (
33
    Conv3D,
34
    AveragePooling3D,
35
    MaxPooling3D
36
)
37
from keras.layers.merge import add
38
from keras.layers.normalization import BatchNormalization
39
from keras.regularizers import l2
40
from keras import backend as K
41
from math import ceil
42
43
def _bn_relu(input):
    """Apply batch normalization followed by a relu activation.

    Adapted from raghakot/keras-resnet.

    # Arguments
        input: tensor to normalize. Normalization runs over the module
            global ``CHANNEL_AXIS``, which ``_handle_data_format()``
            assigns, so that function must have been called beforehand.
    # Returns
        The output tensor of the relu activation.
    """
    batch_norm = BatchNormalization(axis=CHANNEL_AXIS)
    normalized = batch_norm(input)
    return Activation("relu")(normalized)
49
50
51
def _conv_bn_relu3D(**conv_params):
52
    filters = conv_params["filters"]
53
    kernel_size = conv_params["kernel_size"]
54
    strides = conv_params.setdefault("strides", (1, 1, 1))
55
    kernel_initializer = conv_params.setdefault(
56
        "kernel_initializer", "he_normal")
57
    padding = conv_params.setdefault("padding", "same")
58
    kernel_regularizer = conv_params.setdefault("kernel_regularizer",
59
                                                l2(5e-3))
60
61
    def f(input):
62
        conv = Conv3D(filters=filters, kernel_size=kernel_size,
63
                      strides=strides, kernel_initializer=kernel_initializer,
64
                      padding=padding,
65
                      kernel_regularizer=kernel_regularizer)(input)
66
        return _bn_relu(conv)
67
68
    return f
69
70
71
def _bn_relu_conv3d(**conv_params):
72
    """Helper to build a  BN -> relu -> conv3d block.
73
    """
74
    filters = conv_params["filters"]
75
    kernel_size = conv_params["kernel_size"]
76
    strides = conv_params.setdefault("strides", (1, 1, 1))
77
    kernel_initializer = conv_params.setdefault("kernel_initializer",
78
                                                "he_normal")
79
    padding = conv_params.setdefault("padding", "same")
80
    kernel_regularizer = conv_params.setdefault("kernel_regularizer",
81
                                                l2(5e-3))
82
83
    def f(input):
84
        activation = _bn_relu(input)
85
86
        return Conv3D(filters=filters, kernel_size=kernel_size,
87
                      strides=strides, kernel_initializer=kernel_initializer,
88
                      padding=padding,
89
                      kernel_regularizer=kernel_regularizer)(activation)
90
    return f
91
92
93
def _shortcut3d(input, residual):
    """3D shortcut that matches input to residual and merges them by sum.

    If the residual branch differs from ``input`` in any spatial dimension
    or in channel count, ``input`` is first projected with a strided
    1 x 1 x 1 convolution that produces the residual's channel count;
    otherwise it is added unchanged.

    Relies on the module globals ``DIM1_AXIS``, ``DIM2_AXIS``,
    ``DIM3_AXIS`` and ``CHANNEL_AXIS`` assigned by
    ``_handle_data_format()``.

    # Arguments
        input: tensor entering the residual unit.
        residual: tensor produced by the unit's convolution branch.
    # Returns
        The element-wise sum of the (possibly projected) input and residual.
    """
    # Use the public backend helper rather than the private `_keras_shape`
    # attribute to read the static shapes.
    input_shape = K.int_shape(input)
    residual_shape = K.int_shape(residual)

    # Stride per spatial axis = ceil(input size / residual size).
    # int(): math.ceil returns a float on Python 2, but strides must be ints.
    strides = tuple(
        int(ceil(input_shape[axis] * 1.0 / residual_shape[axis]))
        for axis in (DIM1_AXIS, DIM2_AXIS, DIM3_AXIS)
    )
    residual_channels = residual_shape[CHANNEL_AXIS]
    equal_channels = residual_channels == input_shape[CHANNEL_AXIS]

    shortcut = input
    if any(stride > 1 for stride in strides) or not equal_channels:
        # Shapes disagree: project the input so the sum is well defined.
        shortcut = Conv3D(
            filters=residual_channels,
            kernel_size=(1, 1, 1),
            strides=strides,
            kernel_initializer="he_normal", padding="valid",
            kernel_regularizer=l2(5e-3)
            )(input)
    return add([shortcut, residual])
118
119
120
def _residual_block3d(block_function, filters, repetitions,
121
                      is_first_layer=False):
122
    def f(input):
123
        for i in range(repetitions):
124
            strides = (1, 1, 1)
125
            if i == 0 and not is_first_layer:
126
                strides = (2, 2, 2)
127
            input = block_function(filters=filters,
128
                                   strides=strides,
129
                                   is_first_block_of_first_layer=(
130
                                       is_first_layer and i == 0)
131
                                   )(input)
132
        return input
133
134
    return f
135
136
137
def basic_block(filters, strides=(1, 1, 1),
                is_first_block_of_first_layer=False):
    """Residual unit made of two 3 X 3 X 3 convolutions.

    Extended from raghakot's 2D implementation.

    # Arguments
        filters: number of filters of both convolutions.
        strides: strides of the first convolution; the second one uses
            the default of ``_bn_relu_conv3d``.
        is_first_block_of_first_layer: when true, the first convolution
            is applied without a preceding BN -> relu.
    # Returns
        A function that maps an input tensor to the unit's output, i.e.
        the input merged with the residual by ``_shortcut3d``.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            first_conv = _bn_relu_conv3d(filters=filters,
                                         kernel_size=(3, 3, 3),
                                         strides=strides)
        else:
            # The layers feeding the first block already end with
            # bn -> relu, so it is not repeated here.
            first_conv = Conv3D(filters=filters, kernel_size=(3, 3, 3),
                                strides=strides,
                                kernel_initializer="he_normal",
                                padding="same",
                                kernel_regularizer=l2(5e-3))
        conv1 = first_conv(input)

        second_conv = _bn_relu_conv3d(filters=filters,
                                      kernel_size=(3, 3, 3))
        residual = second_conv(conv1)

        return _shortcut3d(input, residual)

    return f
160
161
162
def bottleneck(filters, strides=(1, 1, 1),
               is_first_block_of_first_layer=False):
    """Bottleneck residual unit: 1 X 1 X 1 -> 3 X 3 X 3 -> 1 X 1 X 1.

    Extended from raghakot's 2D implementation.

    # Arguments
        filters: number of filters of the first two convolutions; the
            final 1 X 1 X 1 convolution uses ``filters * 4``.
        strides: strides of the first convolution; the other two use the
            default of ``_bn_relu_conv3d``.
        is_first_block_of_first_layer: when true, the first convolution
            is applied without a preceding BN -> relu.
    # Returns
        A function that maps an input tensor to the unit's output, i.e.
        the input merged with the residual by ``_shortcut3d``.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            reduce_conv = _bn_relu_conv3d(filters=filters,
                                          kernel_size=(1, 1, 1),
                                          strides=strides)
        else:
            # The layers feeding the first block already end with
            # bn -> relu, so it is not repeated here.
            reduce_conv = Conv3D(filters=filters, kernel_size=(1, 1, 1),
                                 strides=strides, padding="same",
                                 kernel_initializer="he_normal",
                                 kernel_regularizer=l2(5e-3))
        conv_1_1 = reduce_conv(input)

        middle_conv = _bn_relu_conv3d(filters=filters,
                                      kernel_size=(3, 3, 3))
        conv_3_3 = middle_conv(conv_1_1)

        expand_conv = _bn_relu_conv3d(filters=filters * 4,
                                      kernel_size=(1, 1, 1))
        residual = expand_conv(conv_3_3)

        return _shortcut3d(input, residual)

    return f
184
185
186
def _handle_data_format():
    """Assign the module-level axis indices for the backend data format.

    Sets the globals ``DIM1_AXIS``, ``DIM2_AXIS``, ``DIM3_AXIS`` and
    ``CHANNEL_AXIS``. When ``K.image_data_format()`` is 'channels_last'
    the three dimension axes are 1, 2, 3 and the channel axis is 4;
    for any other value the channel axis is 1 and the dimension axes
    are 2, 3, 4.
    """
    global DIM1_AXIS, DIM2_AXIS, DIM3_AXIS, CHANNEL_AXIS

    channels_last = K.image_data_format() == 'channels_last'
    if channels_last:
        DIM1_AXIS, DIM2_AXIS, DIM3_AXIS, CHANNEL_AXIS = 1, 2, 3, 4
    else:
        CHANNEL_AXIS, DIM1_AXIS, DIM2_AXIS, DIM3_AXIS = 1, 2, 3, 4
201
202
203
def _get_block(identifier):
    """Resolve a block given either by name or passed directly.

    # Arguments
        identifier: a string naming a module-level global, or any other
            object, which is returned unchanged.
    # Returns
        The global found under the name, or ``identifier`` itself when it
        is not a string.
    # Raises
        ValueError: if the name has no global or the global is falsy.
    """
    if not isinstance(identifier, six.string_types):
        return identifier

    block = globals().get(identifier)
    if block:
        return block
    raise ValueError('Invalid {}'.format(identifier))
210
211
212
class Resnet3DBuilder(object):
    """Factory of 3D ResNet Keras models of several depths."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions):
        """Instantiate a vanilla ResNet3D keras model

        # Arguments
            input_shape: Tuple of 4 values in the format
            (conv_dim1, conv_dim2, conv_dim3, channels) for the
            'channels_last' data format, otherwise
            (channels, conv_dim1, conv_dim2, conv_dim3)
            num_outputs: Number of units of the final dense layer, which
            uses softmax when num_outputs > 1 and sigmoid otherwise
            block_fn: Unit block to use, either a callable or the name of
            a module-level one: {'basic_block', 'bottleneck'}
            repetitions: Sequence with the number of unit blocks of each
            stage; filters start at 32 and double after every stage
        # Returns
            model: a 3D ResNet model that takes a 5D tensor (volumetric
            images in batch) as input and returns a 1D vector
            (prediction) as output.
        # Raises
            ValueError: if input_shape does not have exactly 4 entries.
        """
        # Validate before any global or layer is touched so that a bad
        # shape fails fast and without side effects.
        if len(input_shape) != 4:
            raise ValueError("Input shape should be a tuple "
                             "(conv_dim1, conv_dim2, conv_dim3, channels) "
                             "for tensorflow as backend or "
                             "(channels, conv_dim1, conv_dim2, conv_dim3) "
                             "for theano as backend")

        _handle_data_format()
        block_fn = _get_block(block_fn)
        inputs = Input(shape=input_shape)
        filters = 32

        # stem: strided 7x7x7 conv, max pooling, then a 3x3x3 conv that
        # does not downsample the third dimension
        conv1 = _conv_bn_relu3D(filters=filters, kernel_size=(7, 7, 7),
                                strides=(2, 2, 2))(inputs)
        pool1 = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2),
                             padding="same")(conv1)
        conv2 = _conv_bn_relu3D(filters=filters, kernel_size=(3, 3, 3),
                                strides=(2, 2, 1))(pool1)

        # repeat blocks
        block = conv2
        for i, r in enumerate(repetitions):
            block = _residual_block3d(block_fn, filters=filters,
                                      repetitions=r, is_first_layer=(i == 0)
                                      )(block)
            filters *= 2

        # last activation
        block_output = _bn_relu(block)

        # global average pooling and classification
        pooled = GlobalAveragePooling3D()(block_output)
        activation = "softmax" if num_outputs > 1 else "sigmoid"
        dense = Dense(units=num_outputs,
                      kernel_initializer="he_normal",
                      activation=activation,
                      kernel_regularizer=l2(5e-3))(pooled)

        return Model(inputs=inputs, outputs=dense)

    @staticmethod
    def build_resnet_18(input_shape, num_outputs):
        """Build a model of basic blocks with repetitions [2, 2, 2, 2]."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [2, 2, 2, 2])

    @staticmethod
    def build_resnet_34(input_shape, num_outputs):
        """Build a model of basic blocks with repetitions [3, 4, 6, 3]."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [3, 4, 6, 3])

    @staticmethod
    def build_resnet_50(input_shape, num_outputs):
        """Build a model of bottleneck blocks with repetitions [3, 4, 6, 3]."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 6, 3])

    @staticmethod
    def build_resnet_101(input_shape, num_outputs):
        """Build a model of bottleneck blocks with repetitions [3, 4, 23, 3]."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 23, 3])

    @staticmethod
    def build_resnet_152(input_shape, num_outputs):
        """Build a model of bottleneck blocks with repetitions [3, 8, 36, 3]."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 8, 36, 3])