a b/EfficientNet_2d/EfficientNet_2d.py
1
import torch
2
from torch import nn
3
from torch.nn import functional as F
4
from EfficientNet_2d.utils import (
5
    round_filters,
6
    round_repeats,
7
    drop_connect,
8
    get_same_padding_conv2d,
9
    get_model_params,
10
    efficientnet_params,
11
    load_pretrained_weights,
12
    Swish,
13
    MemoryEfficientSwish,
14
)
15
16
17
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck (MBConv) block.

    Pipeline: optional 1x1 expansion conv -> depthwise conv -> optional
    squeeze-and-excitation -> 1x1 projection conv, with an identity shortcut
    (plus optional drop connect) when the block keeps stride 1 and the same
    number of filters.

    Args:
        block_args (namedtuple): BlockArgs of this block (input/output filters,
            kernel size, stride, expand ratio, SE ratio, id_skip).
        global_params (namedtuple): GlobalParams shared by all blocks
            (batch-norm momentum/epsilon, image size).

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
        id_skip: Truthy when the shortcut connection may be used.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        # BatchNorm momentum is stored as the complement of the configured value.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Same-padding conv class: static or dynamic depending on image size.
        conv_cls = get_same_padding_conv2d(image_size=global_params.image_size)

        in_ch = block_args.input_filters
        expand_ratio = block_args.expand_ratio
        hidden_ch = in_ch * expand_ratio

        # Expansion phase (skipped entirely when the ratio is 1).
        if expand_ratio != 1:
            self._expand_conv = conv_cls(in_channels=in_ch, out_channels=hidden_ch, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=hidden_ch, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise phase: groups == channels makes the convolution depthwise.
        self._depthwise_conv = conv_cls(
            in_channels=hidden_ch, out_channels=hidden_ch, groups=hidden_ch,
            kernel_size=block_args.kernel_size, stride=block_args.stride, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=hidden_ch, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze-and-excitation; the bottleneck width is derived from the
        # block's *input* filters, not the expanded width.
        if self.has_se:
            squeezed_ch = max(1, int(in_ch * se_ratio))
            self._se_reduce = conv_cls(in_channels=hidden_ch, out_channels=squeezed_ch, kernel_size=1)
            self._se_expand = conv_cls(in_channels=squeezed_ch, out_channels=hidden_ch, kernel_size=1)

        # Projection phase back to the block's output width.
        out_ch = block_args.output_filters
        self._project_conv = conv_cls(in_channels=hidden_ch, out_channels=out_ch, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=out_ch, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        Run the block.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1); only
            used when the identity shortcut is active
        :return: output of block
        """
        args = self._block_args

        out = inputs
        if args.expand_ratio != 1:
            out = self._swish(self._bn0(self._expand_conv(out)))
        out = self._swish(self._bn1(self._depthwise_conv(out)))

        if self.has_se:
            gate = F.adaptive_avg_pool2d(out, 1)
            gate = self._se_expand(self._swish(self._se_reduce(gate)))
            out = torch.sigmoid(gate) * out

        out = self._bn2(self._project_conv(out))

        # The shortcut only applies when spatial size and channel count are kept.
        keeps_shape = args.stride == 1 and args.input_filters == args.output_filters
        if not (self.id_skip and keeps_shape):
            return out
        if drop_connect_rate:
            out = drop_connect(out, p=drop_connect_rate, training=self.training)
        return out + inputs  # skip connection

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
99
100
101
class EfficientNet(nn.Module):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks
        in_channels (int): Number of channels of the input images (default 3)

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None, in_channels=3):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters (momentum stored as the complement of the configured value)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        stem_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=stem_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head (separate local so the `in_channels` argument is not shadowed)
        head_in_channels = block_args.output_filters  # output of final block
        head_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(head_in_channels, head_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=head_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(head_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks; the drop connect rate grows linearly with block depth.
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer; pooled shape (batch, C, 1, 1) -> (batch, C)
        x = self._avg_pooling(x)
        x = torch.flatten(x, start_dim=1)
        x = self._dropout(x)
        x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None, in_channels=3):
        """Build an untrained model for ``model_name`` (e.g. 'efficientnet-b0').

        :param override_params: optional overrides forwarded to ``get_model_params``
        :param in_channels: number of input image channels (default 3)
        :raises ValueError: if ``model_name`` is not a known EfficientNet variant
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(blocks_args, global_params, in_channels=in_channels)

    @classmethod
    def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
        """Build ``model_name`` and load its pretrained weights.

        The final linear layer is loaded only when ``num_classes == 1000``. When
        ``in_channels != 3`` the stem convolution is replaced, after loading, by a
        newly constructed one accepting ``in_channels`` channels; that layer is
        therefore not loaded from the checkpoint.
        """
        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the input resolution listed by ``efficientnet_params`` for ``model_name``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """ Validates model name (efficientnet-b0 ... efficientnet-b8). """
        valid_models = ['efficientnet-b' + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
226
227
228
# Get a pretrained EfficientNet with a new head for k-class classification
229
def get_pretrained_EfficientNet(num_classes, model_name='efficientnet-b0'):
    """Return a pretrained EfficientNet whose final linear layer outputs ``num_classes`` logits.

    The model is loaded with ``EfficientNet.from_pretrained`` using its default
    1000-class head, after which ``_fc`` is replaced by a freshly constructed
    ``nn.Linear`` with the same number of input features.

    Args:
        num_classes (int): number of outputs of the new classification layer.
        model_name (str): EfficientNet variant to load; defaults to
            'efficientnet-b0', the only variant this helper originally used.

    Returns:
        EfficientNet: the model with the replaced classification layer.
    """
    model = EfficientNet.from_pretrained(model_name)
    fc_features = model._fc.in_features
    model._fc = nn.Linear(fc_features, num_classes)
    return model
234
235
236
class DAR_Effi(nn.Module):
    """
    Three-branch EfficientNet (Prd-Net, CF-Net and LR-Net) with cross-branch attention.

    Each branch owns its own stem, MBConv stack, head and classifier. For every
    block index ``>= att_start`` the Prd-Net features are re-weighted by
    :meth:`attention` using the CF-Net and LR-Net features of the same block.
    ``forward`` returns the logits of all three branches.

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks
        in_channels (int): Number of channels of the input images (default 3)
        att_start (int): Index of the first block whose output gets attention applied
    """

    def __init__(self, blocks_args=None, global_params=None, in_channels=3, att_start=11):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        self.att_start = att_start  # for CA-module and NA-module

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters (momentum stored as the complement of the configured value)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # NOTE: submodules are created in the same order as the original
        # copy-pasted version, so parameter initialisation order (RNG use) and
        # state_dict keys are unchanged.

        # Stem, one per branch
        stem_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=stem_channels, momentum=bn_mom, eps=bn_eps)
        self._conv_stem_cf = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0_cf = nn.BatchNorm2d(num_features=stem_channels, momentum=bn_mom, eps=bn_eps)
        self._conv_stem_lr = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0_lr = nn.BatchNorm2d(num_features=stem_channels, momentum=bn_mom, eps=bn_eps)

        # MBConv stacks of Prd-Net, CF-Net and LR-Net (same architecture, separate weights)
        self._blocks = self._build_blocks()
        self._blocks_cf = self._build_blocks()
        self._blocks_lr = self._build_blocks()

        # Head, one per branch; its input width is the (rounded) output width of the last stage
        head_in_channels = round_filters(self._blocks_args[-1].output_filters, self._global_params)
        head_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(head_in_channels, head_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=head_channels, momentum=bn_mom, eps=bn_eps)
        self._conv_head_cf = Conv2d(head_in_channels, head_channels, kernel_size=1, bias=False)
        self._bn1_cf = nn.BatchNorm2d(num_features=head_channels, momentum=bn_mom, eps=bn_eps)
        self._conv_head_lr = Conv2d(head_in_channels, head_channels, kernel_size=1, bias=False)
        self._bn1_lr = nn.BatchNorm2d(num_features=head_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layers
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(head_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

        self._avg_pooling_cf = nn.AdaptiveAvgPool2d(1)
        self._dropout_cf = nn.Dropout(self._global_params.dropout_rate)
        self._fc_cf = nn.Linear(head_channels, self._global_params.num_classes)
        self._swish_cf = MemoryEfficientSwish()

        self._avg_pooling_lr = nn.AdaptiveAvgPool2d(1)
        self._dropout_lr = nn.Dropout(self._global_params.dropout_rate)
        self._fc_lr = nn.Linear(head_channels, self._global_params.num_classes)
        self._swish_lr = MemoryEfficientSwish()

    def _build_blocks(self):
        """Build one MBConv stack from ``self._blocks_args`` scaled by ``self._global_params``."""
        blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                blocks.append(MBConvBlock(block_args, self._global_params))
        return blocks

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        self._swish_cf = MemoryEfficientSwish() if memory_efficient else Swish()
        self._swish_lr = MemoryEfficientSwish() if memory_efficient else Swish()
        for blocks in (self._blocks, self._blocks_cf, self._blocks_lr):
            for block in blocks:
                block.set_swish(memory_efficient)

    def attention(self, f_prd, f_cf, f_lr):
        """Re-weight Prd-Net features with the CF-Net and LR-Net features.

        With ``w_cf = 1 - sigmoid(f_cf)`` and
        ``w_lr = 1 - |sigmoid(f_prd) - sigmoid(f_lr)|`` this returns
        ``f_prd + w_cf * f_prd + w_lr * f_prd`` (all element-wise).
        """
        w_cf = 1 - torch.sigmoid(f_cf)
        add_cf = w_cf * f_prd

        w_lr = 1 - torch.abs(torch.sigmoid(f_prd) - torch.sigmoid(f_lr))
        add_lr = w_lr * f_prd

        return f_prd + add_cf + add_lr

    def extract_features(self, inputs):
        """ Returns the outputs of the final convolution layer of each branch (prd, cf, lr). """

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        x_cf = self._swish_cf(self._bn0_cf(self._conv_stem_cf(inputs)))
        x_lr = self._swish_lr(self._bn0_lr(self._conv_stem_lr(inputs)))

        # Blocks: the three stacks are advanced in lockstep; drop connect grows with depth.
        num_blocks = len(self._blocks)
        stacks = zip(self._blocks, self._blocks_cf, self._blocks_lr)
        for idx, (block, block_cf, block_lr) in enumerate(stacks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks

            x = block(x, drop_connect_rate=drop_connect_rate)
            x_cf = block_cf(x_cf, drop_connect_rate=drop_connect_rate)
            x_lr = block_lr(x_lr, drop_connect_rate=drop_connect_rate)

            if idx >= self.att_start:
                x = self.attention(x, x_cf, x_lr)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        x_cf = self._swish_cf(self._bn1_cf(self._conv_head_cf(x_cf)))
        x_lr = self._swish_lr(self._bn1_lr(self._conv_head_lr(x_lr)))

        return x, x_cf, x_lr

    @staticmethod
    def _classify(features, pooling, dropout, fc):
        """Pool, flatten to (batch, -1), apply dropout and the linear layer for one branch."""
        out = pooling(features)
        out = out.view(out.size(0), -1)
        out = dropout(out)
        return fc(out)

    def forward(self, inputs):
        """Return the logits of the Prd-Net, CF-Net and LR-Net branches, in that order."""
        # Convolution layers
        x, x_cf, x_lr = self.extract_features(inputs)

        # Pooling and final linear layers
        x = self._classify(x, self._avg_pooling, self._dropout, self._fc)
        x_cf = self._classify(x_cf, self._avg_pooling_cf, self._dropout_cf, self._fc_cf)
        x_lr = self._classify(x_lr, self._avg_pooling_lr, self._dropout_lr, self._fc_lr)

        return x, x_cf, x_lr

    @classmethod
    def from_name(cls, model_name, override_params=None, in_channels=3, att_start=11):
        """Build an untrained DAR_Effi for ``model_name`` (e.g. 'efficientnet-b0').

        :raises ValueError: if ``model_name`` is not a known EfficientNet variant
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(blocks_args, global_params, in_channels, att_start)

    @classmethod
    def get_image_size(cls, model_name):
        """Return the input resolution listed by ``efficientnet_params`` for ``model_name``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """ Validates model name (efficientnet-b0 ... efficientnet-b8). """
        valid_models = ['efficientnet-b' + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
437
438
439
def get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes, model_name='efficientnet-b0'):
    """Build a DAR_Effi and initialise its three branches from single-branch state dicts.

    Every key of ``prd_params`` is written unchanged into the model's state dict
    (Prd-Net branch). The same key is looked up in ``cf_params`` / ``lr_params``,
    and that value is written under a renamed key: ``_cf`` / ``_lr`` is appended
    to the key's first dotted component (``_blocks.0.x`` -> ``_blocks_cf.0.x``,
    ``_fc.weight`` -> ``_fc_lr.weight``). The three dicts presumably come from
    EfficientNet models with identical key layouts — TODO confirm with callers.

    Args:
        prd_params (dict): state dict for the Prd-Net branch.
        cf_params (dict): state dict for the CF-Net branch (same keys as prd_params).
        lr_params (dict): state dict for the LR-Net branch (same keys as prd_params).
        num_classes (int): output size of the three classification layers.
        model_name (str): EfficientNet variant to build; defaults to
            'efficientnet-b0', the only variant this helper originally used.

    Returns:
        DAR_Effi: the model with the loaded weights.

    Raises:
        KeyError: if ``cf_params`` or ``lr_params`` lacks a key of ``prd_params``.
        RuntimeError: from the strict ``load_state_dict`` if keys or shapes do not
            match the model.
    """
    dar_model = DAR_Effi.from_name(model_name)
    fc_features = dar_model._fc.in_features
    dar_model._fc = nn.Linear(fc_features, num_classes)
    dar_model._fc_cf = nn.Linear(fc_features, num_classes)
    dar_model._fc_lr = nn.Linear(fc_features, num_classes)
    dar_params = dar_model.state_dict()

    for key, value in prd_params.items():
        # Split at the first '.' only. partition() keeps dot-less keys intact
        # ("foo" -> "foo_cf"), whereas slicing at find('.') == -1 mangled them.
        module_name, dot, rest = key.partition('.')
        dar_params[key] = value
        dar_params[module_name + '_cf' + dot + rest] = cf_params[key]
        dar_params[module_name + '_lr' + dot + rest] = lr_params[key]

    dar_model.load_state_dict(dar_params)
    return dar_model