Diff of /EfficientNet_2d/utils.py [000000] .. [b76f82]

Switch to unified view

a b/EfficientNet_2d/utils.py
1
"""
2
This file contains helper functions for building the model and for loading model parameters.
3
These helper functions are built to mirror those in the official TensorFlow implementation.
4
"""
5
6
import re
7
import math
8
import collections
9
from functools import partial
10
import torch
11
from torch import nn
12
from torch.nn import functional as F
13
from torch.utils import model_zoo
14
15
########################################################################
16
############### HELPER FUNCTIONS FOR MODEL ARCHITECTURE ################
17
########################################################################
18
19
20
# Settings that apply to the model as a whole (stem, all blocks, and head).
GlobalParams = collections.namedtuple(
    'GlobalParams',
    'batch_norm_momentum batch_norm_epsilon dropout_rate '
    'num_classes width_coefficient depth_coefficient '
    'depth_divisor min_depth drop_connect_rate image_size',
)

# Settings that describe a single model block.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    'kernel_size num_repeat input_filters output_filters '
    'expand_ratio id_skip stride se_ratio',
)

# Give every field a default of None so either tuple can be built from any
# subset of its fields.
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)
34
35
36
class SwishImplementation(torch.autograd.Function):
    """Swish (``x * sigmoid(x)``) written as a custom autograd function.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than being kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for the backward pass."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to the tensor given to ``forward``."""
        # ``saved_tensors`` is the supported accessor for what
        # ``save_for_backward`` stored; ``saved_variables`` is deprecated.
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
48
49
50
class MemoryEfficientSwish(nn.Module):
    """Swish activation module that delegates to ``SwishImplementation``."""

    def forward(self, x):
        """Run ``x`` through the custom-autograd Swish function."""
        activated = SwishImplementation.apply(x)
        return activated
53
54
55
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, built from ordinary tensor ops."""

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
58
59
60
def round_filters(filters, global_params):
    """Scale a filter count by the width coefficient and round it to the divisor.

    :param filters: base number of filters
    :param global_params: object exposing ``width_coefficient``,
        ``depth_divisor`` and ``min_depth``
    :return: ``filters`` unchanged when the width coefficient is falsy;
        otherwise the scaled count rounded to a multiple of ``depth_divisor``,
        never below ``min_depth`` (or the divisor when ``min_depth`` is falsy)
    """
    width = global_params.width_coefficient
    if not width:
        return filters

    step = global_params.depth_divisor
    lower_bound = global_params.min_depth or step
    scaled = filters * width

    # Round to the nearest multiple of ``step``, then enforce the lower bound.
    rounded = max(lower_bound, int(scaled + step / 2) // step * step)
    # Do not let rounding shrink the scaled value by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += step
    return int(rounded)
73
74
75
def round_repeats(repeats, global_params):
    """Scale a repeat count by the depth coefficient, rounding up.

    :param repeats: base number of block repeats
    :param global_params: object exposing ``depth_coefficient``
    :return: ``repeats`` unchanged when the depth coefficient is falsy,
        otherwise ``ceil(depth_coefficient * repeats)`` as an int
    """
    depth = global_params.depth_coefficient
    return int(math.ceil(depth * repeats)) if depth else repeats
81
82
83
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch and rescale the rest.

    :param inputs: tensor whose first dimension indexes the samples of the
        batch; it may have any number of further dimensions
    :param p: probability of dropping a sample, in ``[0, 1]``
    :param training: when falsy, ``inputs`` is returned untouched
    :return: ``inputs`` itself when not training; otherwise a tensor of the
        same shape in which each sample is either all zeros or the original
        sample divided by ``1 - p``
    :raises ValueError: if ``training`` is truthy and ``p`` is outside ``[0, 1]``
    """
    if not training:
        return inputs
    if not 0 <= p <= 1:
        raise ValueError('p must be in range of [0, 1], got {}'.format(p))

    keep_prob = 1 - p
    if keep_prob == 0:
        # Every sample is dropped; dividing by keep_prob would give NaN/inf.
        return torch.zeros_like(inputs)

    # One random draw per sample, broadcast over all remaining dimensions
    # (this equals [batch_size, 1, 1, 1] for the usual 4-D input).
    mask_shape = [inputs.shape[0]] + [1] * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor
93
94
95
def get_same_padding_conv2d(image_size=None):
    """Pick the same-padding convolution class to build layers from.

    Returns ``Conv2dStaticSamePadding`` pre-bound to ``image_size`` when a size
    is given, and ``Conv2dDynamicSamePadding`` otherwise.
    Static padding is necessary for ONNX exporting of models.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
102
103
104
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding for a dynamic image size.

    The amount of padding is recomputed in every ``forward`` call from the
    last two dimensions of the input, targeting an output size of
    ``ceil(input_size / stride)`` along each of them.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # The built-in padding is fixed at 0; padding is applied in forward().
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,
                         dilation=dilation, groups=groups, bias=bias)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        """Pad ``x`` as needed, then apply the convolution."""
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        d_h, d_w = self.dilation[0], self.dilation[1]

        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)

        # Total padding needed along each axis to reach the target output size.
        extra_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
        extra_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)

        if extra_h > 0 or extra_w > 0:
            # When the total is odd, the extra pixel goes to the right/bottom.
            left, top = extra_w // 2, extra_h // 2
            x = F.pad(x, [left, extra_w - left, top, extra_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
121
122
123
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding for a fixed image size.

    The padding is computed once in ``__init__`` from ``image_size`` and stored
    as the ``static_padding`` sub-module, which ``forward`` applies before the
    convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        """
        :param in_channels: forwarded to ``nn.Conv2d``
        :param out_channels: forwarded to ``nn.Conv2d``
        :param kernel_size: forwarded to ``nn.Conv2d``
        :param image_size: required; an int for a square image, or a
            ``(height, width)`` list or tuple
        :param kwargs: further ``nn.Conv2d`` keyword arguments (e.g. ``stride``)
        """
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None, 'image_size is required for static same padding'
        # Accept tuples as well as lists; anything else is one size for both axes.
        ih, iw = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # When the total is odd, the extra pixel goes to the right/bottom.
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the convolution."""
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
147
148
149
class Identity(nn.Module):
    """No-op module: ``forward`` hands its argument back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` as-is."""
        return input
155
156
157
########################################################################
158
############## HELPER FUNCTIONS FOR LOADING MODEL PARAMS ###############
159
########################################################################
160
161
162
def efficientnet_params(model_name):
    """Look up the scaling coefficients for a named EfficientNet variant.

    :param model_name: a key such as ``'efficientnet-b0'``
    :return: tuple ``(width, depth, resolution, dropout)``
    :raises KeyError: if ``model_name`` is not in the table
    """
    # (variant suffix, (width, depth, res, dropout))
    coefficients = (
        ('b0', (1.0, 1.0, 224, 0.2)),
        ('b1', (1.0, 1.1, 240, 0.2)),
        ('b2', (1.1, 1.2, 260, 0.3)),
        ('b3', (1.2, 1.4, 300, 0.3)),
        ('b4', (1.4, 1.8, 380, 0.4)),
        ('b5', (1.6, 2.2, 456, 0.4)),
        ('b6', (1.8, 2.6, 528, 0.5)),
        ('b7', (2.0, 3.1, 600, 0.5)),
        ('b8', (2.2, 3.6, 672, 0.5)),
        ('l2', (4.3, 5.3, 800, 0.5)),
    )
    table = {'efficientnet-' + variant: values for variant, values in coefficients}
    return table[model_name]
178
179
180
class BlockDecoder(object):
    """Convert between block-notation strings and ``BlockArgs`` tuples.

    A block string is a ``_``-separated list of options such as
    ``'r1_k3_s11_e1_i32_o16_se0.25'``. The options map to fields as follows:
    ``r`` num_repeat, ``k`` kernel_size, ``s`` stride, ``e`` expand_ratio,
    ``i`` input_filters, ``o`` output_filters, ``se`` se_ratio; the literal
    ``noskip`` sets ``id_skip`` to False.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Build a ``BlockArgs`` from one block string.

        :raises AssertionError: if ``block_string`` is not a ``str``, or its
            stride option is missing or is not one digit / two equal digits
        :raises KeyError: if one of the ``k``, ``r``, ``i``, ``o`` or ``e``
            options is missing
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            # Split at the first digit: 'se0.25' -> key 'se', value '0.25'.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: it must be present and be either a single digit or two
        # identical digits (only the first digit is kept below).
        stride = options.get('s', '')
        assert (len(stride) == 1 or
                (len(stride) == 2 and stride[0] == stride[1])), \
            'missing or invalid stride in block string: %r' % block_string

        return BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=[int(stride[0])])

    @staticmethod
    def _encode_block_string(block):
        """Encode one block (an object with the ``BlockArgs`` fields) to a string.

        ``block.stride`` may be an int, a one-element sequence (the form
        produced by decoding) or a two-element sequence.
        """
        stride = block.stride
        if isinstance(stride, int):
            stride = [stride]
        stride_h = stride[0]
        # A single stride value applies to both axes.
        stride_w = stride[1] if len(stride) > 1 else stride[0]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        # se_ratio is None when the block string carried no 'se' option.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """
        Decodes a list of string notations to specify blocks inside the network.

        :param string_list: a list of strings, each string is a notation of block
        :return: a list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
253
254
255
def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
                 drop_connect_rate=0.2, image_size=None, num_classes=1000):
    """Build the block arguments and global parameters for an EfficientNet.

    :return: tuple ``(blocks_args, global_params)``: the list produced by
        ``BlockDecoder.decode`` for the seven block strings below, and a
        ``GlobalParams`` carrying the given arguments plus fixed
        batch-norm and depth-divisor settings
    """
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_blocks = BlockDecoder.decode(block_strings)

    # data_format='channels_last' was removed, this is always true in PyTorch
    settings = {
        'batch_norm_momentum': 0.99,
        'batch_norm_epsilon': 1e-3,
        'dropout_rate': dropout_rate,
        'drop_connect_rate': drop_connect_rate,
        'num_classes': num_classes,
        'width_coefficient': width_coefficient,
        'depth_coefficient': depth_coefficient,
        'depth_divisor': 8,
        'min_depth': None,
        'image_size': image_size,
    }
    return decoded_blocks, GlobalParams(**settings)
282
283
284
def get_model_params(model_name, override_params):
    """Return ``(blocks_args, global_params)`` for a named model.

    :param model_name: must start with ``'efficientnet'``
    :param override_params: optional mapping of ``global_params`` fields to
        replace; skipped when falsy
    :raises NotImplementedError: if ``model_name`` does not start with
        ``'efficientnet'``
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
297
298
299
# Download URLs of the default pretrained checkpoints, keyed by model name.
# Each entry below is (variant suffix, checksum fragment of the file name).
url_map = {
    'efficientnet-' + variant:
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-%s-%s.pth'
        % (variant, digest)
    for variant, digest in (
        ('b0', '355c32eb'),
        ('b1', 'f1951068'),
        ('b2', '8bb594d6'),
        ('b3', '5fb5a3c3'),
        ('b4', '6ed6700e'),
        ('b5', 'b6417697'),
        ('b6', 'c76e70fd'),
        ('b7', 'dcc49843'),
    )
}

# Download URLs of the AdvProp pretrained checkpoints, keyed by model name.
url_map_advprop = {
    'efficientnet-' + variant:
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-%s-%s.pth'
        % (variant, digest)
    for variant, digest in (
        ('b0', 'b64d5a18'),
        ('b1', '0f3ce85a'),
        ('b2', '6e9d97e5'),
        ('b3', 'cdd7c0f4'),
        ('b4', '44fb3a87'),
        ('b5', '86493f6b'),
        ('b6', 'ac80338e'),
        ('b7', '4652b6dd'),
        ('b8', '22a8fe65'),
    )
}
322
323
324
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Load pretrained weights into ``model``, fetching them via ``model_zoo.load_url``.

    :param model: module that receives the state dict
    :param model_name: key into ``url_map`` (or ``url_map_advprop``)
    :param load_fc: when False, the ``_fc.weight`` / ``_fc.bias`` entries are
        removed from the checkpoint before loading
    :param advprop: choose the AdvProp URL table instead of the default one
    :raises KeyError: if ``model_name`` is not in the selected URL table
    """
    # AutoAugment or Advprop (different preprocessing)
    checkpoint_urls = url_map_advprop if advprop else url_map
    weights = model_zoo.load_url(checkpoint_urls[model_name])

    if not load_fc:
        for head_key in ('_fc.weight', '_fc.bias'):
            weights.pop(head_key)
        # Non-strict load: the two head entries are expected to be missing,
        # and they must be the only ones.
        outcome = model.load_state_dict(weights, strict=False)
        assert set(outcome.missing_keys) == {'_fc.weight', '_fc.bias'}, 'issue loading pretrained weights'
    else:
        model.load_state_dict(weights)

    print('Loaded pretrained weights for {}'.format(model_name))