|
a |
|
b/EfficientNet_2d/utils.py |
|
|
1 |
""" |
|
|
2 |
This file contains helper functions for building the model and for loading model parameters. |
|
|
3 |
These helper functions are built to mirror those in the official TensorFlow implementation. |
|
|
4 |
""" |
|
|
5 |
|
|
|
6 |
import re |
|
|
7 |
import math |
|
|
8 |
import collections |
|
|
9 |
from functools import partial |
|
|
10 |
import torch |
|
|
11 |
from torch import nn |
|
|
12 |
from torch.nn import functional as F |
|
|
13 |
from torch.utils import model_zoo |
|
|
14 |
|
|
|
15 |
######################################################################## |
|
|
16 |
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### |
|
|
17 |
######################################################################## |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
# Parameters for the entire model (stem, all blocks, and head) |
|
|
21 |
GlobalParams = collections.namedtuple('GlobalParams', [ |
|
|
22 |
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', |
|
|
23 |
'num_classes', 'width_coefficient', 'depth_coefficient', |
|
|
24 |
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) |
|
|
25 |
|
|
|
26 |
# Parameters for an individual model block |
|
|
27 |
BlockArgs = collections.namedtuple('BlockArgs', [ |
|
|
28 |
'kernel_size', 'num_repeat', 'input_filters', 'output_filters', |
|
|
29 |
'expand_ratio', 'id_skip', 'stride', 'se_ratio']) |
|
|
30 |
|
|
|
31 |
# Change namedtuple defaults |
|
|
32 |
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) |
|
|
33 |
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
class SwishImplementation(torch.autograd.Function): |
|
|
37 |
@staticmethod |
|
|
38 |
def forward(ctx, i): |
|
|
39 |
result = i * torch.sigmoid(i) |
|
|
40 |
ctx.save_for_backward(i) |
|
|
41 |
return result |
|
|
42 |
|
|
|
43 |
@staticmethod |
|
|
44 |
def backward(ctx, grad_output): |
|
|
45 |
i = ctx.saved_variables[0] |
|
|
46 |
sigmoid_i = torch.sigmoid(i) |
|
|
47 |
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
class MemoryEfficientSwish(nn.Module): |
|
|
51 |
def forward(self, x): |
|
|
52 |
return SwishImplementation.apply(x) |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
class Swish(nn.Module): |
|
|
56 |
def forward(self, x): |
|
|
57 |
return x * torch.sigmoid(x) |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def round_filters(filters, global_params): |
|
|
61 |
""" Calculate and round number of filters based on depth multiplier. """ |
|
|
62 |
multiplier = global_params.width_coefficient |
|
|
63 |
if not multiplier: |
|
|
64 |
return filters |
|
|
65 |
divisor = global_params.depth_divisor |
|
|
66 |
min_depth = global_params.min_depth |
|
|
67 |
filters *= multiplier |
|
|
68 |
min_depth = min_depth or divisor |
|
|
69 |
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) |
|
|
70 |
if new_filters < 0.9 * filters: # prevent rounding by more than 10% |
|
|
71 |
new_filters += divisor |
|
|
72 |
return int(new_filters) |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
def round_repeats(repeats, global_params): |
|
|
76 |
""" Round number of filters based on depth multiplier. """ |
|
|
77 |
multiplier = global_params.depth_coefficient |
|
|
78 |
if not multiplier: |
|
|
79 |
return repeats |
|
|
80 |
return int(math.ceil(multiplier * repeats)) |
|
|
81 |
|
|
|
82 |
|
|
|
83 |
def drop_connect(inputs, p, training): |
|
|
84 |
""" Drop connect. """ |
|
|
85 |
if not training: return inputs |
|
|
86 |
batch_size = inputs.shape[0] |
|
|
87 |
keep_prob = 1 - p |
|
|
88 |
random_tensor = keep_prob |
|
|
89 |
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) |
|
|
90 |
binary_tensor = torch.floor(random_tensor) |
|
|
91 |
output = inputs / keep_prob * binary_tensor |
|
|
92 |
return output |
|
|
93 |
|
|
|
94 |
|
|
|
95 |
def get_same_padding_conv2d(image_size=None): |
|
|
96 |
""" Chooses static padding if you have specified an image size, and dynamic padding otherwise. |
|
|
97 |
Static padding is necessary for ONNX exporting of models. """ |
|
|
98 |
if image_size is None: |
|
|
99 |
return Conv2dDynamicSamePadding |
|
|
100 |
else: |
|
|
101 |
return partial(Conv2dStaticSamePadding, image_size=image_size) |
|
|
102 |
|
|
|
103 |
|
|
|
104 |
class Conv2dDynamicSamePadding(nn.Conv2d): |
|
|
105 |
""" 2D Convolutions like TensorFlow, for a dynamic image size """ |
|
|
106 |
|
|
|
107 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): |
|
|
108 |
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) |
|
|
109 |
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 |
|
|
110 |
|
|
|
111 |
def forward(self, x): |
|
|
112 |
ih, iw = x.size()[-2:] |
|
|
113 |
kh, kw = self.weight.size()[-2:] |
|
|
114 |
sh, sw = self.stride |
|
|
115 |
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) |
|
|
116 |
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) |
|
|
117 |
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) |
|
|
118 |
if pad_h > 0 or pad_w > 0: |
|
|
119 |
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) |
|
|
120 |
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) |
|
|
121 |
|
|
|
122 |
|
|
|
123 |
class Conv2dStaticSamePadding(nn.Conv2d): |
|
|
124 |
""" 2D Convolutions like TensorFlow, for a fixed image size""" |
|
|
125 |
|
|
|
126 |
def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): |
|
|
127 |
super().__init__(in_channels, out_channels, kernel_size, **kwargs) |
|
|
128 |
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 |
|
|
129 |
|
|
|
130 |
# Calculate padding based on image size and save it |
|
|
131 |
assert image_size is not None |
|
|
132 |
ih, iw = image_size if type(image_size) == list else [image_size, image_size] |
|
|
133 |
kh, kw = self.weight.size()[-2:] |
|
|
134 |
sh, sw = self.stride |
|
|
135 |
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) |
|
|
136 |
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) |
|
|
137 |
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) |
|
|
138 |
if pad_h > 0 or pad_w > 0: |
|
|
139 |
self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) |
|
|
140 |
else: |
|
|
141 |
self.static_padding = Identity() |
|
|
142 |
|
|
|
143 |
def forward(self, x): |
|
|
144 |
x = self.static_padding(x) |
|
|
145 |
x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) |
|
|
146 |
return x |
|
|
147 |
|
|
|
148 |
|
|
|
149 |
class Identity(nn.Module): |
|
|
150 |
def __init__(self, ): |
|
|
151 |
super(Identity, self).__init__() |
|
|
152 |
|
|
|
153 |
def forward(self, input): |
|
|
154 |
return input |
|
|
155 |
|
|
|
156 |
|
|
|
157 |
######################################################################## |
|
|
158 |
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## |
|
|
159 |
######################################################################## |
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def efficientnet_params(model_name): |
|
|
163 |
""" Map EfficientNet model name to parameter coefficients. """ |
|
|
164 |
params_dict = { |
|
|
165 |
# Coefficients: width,depth,res,dropout |
|
|
166 |
'efficientnet-b0': (1.0, 1.0, 224, 0.2), |
|
|
167 |
'efficientnet-b1': (1.0, 1.1, 240, 0.2), |
|
|
168 |
'efficientnet-b2': (1.1, 1.2, 260, 0.3), |
|
|
169 |
'efficientnet-b3': (1.2, 1.4, 300, 0.3), |
|
|
170 |
'efficientnet-b4': (1.4, 1.8, 380, 0.4), |
|
|
171 |
'efficientnet-b5': (1.6, 2.2, 456, 0.4), |
|
|
172 |
'efficientnet-b6': (1.8, 2.6, 528, 0.5), |
|
|
173 |
'efficientnet-b7': (2.0, 3.1, 600, 0.5), |
|
|
174 |
'efficientnet-b8': (2.2, 3.6, 672, 0.5), |
|
|
175 |
'efficientnet-l2': (4.3, 5.3, 800, 0.5), |
|
|
176 |
} |
|
|
177 |
return params_dict[model_name] |
|
|
178 |
|
|
|
179 |
|
|
|
180 |
class BlockDecoder(object): |
|
|
181 |
""" Block Decoder for readability, straight from the official TensorFlow repository """ |
|
|
182 |
|
|
|
183 |
@staticmethod |
|
|
184 |
def _decode_block_string(block_string): |
|
|
185 |
""" Gets a block through a string notation of arguments. """ |
|
|
186 |
assert isinstance(block_string, str) |
|
|
187 |
|
|
|
188 |
ops = block_string.split('_') |
|
|
189 |
options = {} |
|
|
190 |
for op in ops: |
|
|
191 |
splits = re.split(r'(\d.*)', op) |
|
|
192 |
if len(splits) >= 2: |
|
|
193 |
key, value = splits[:2] |
|
|
194 |
options[key] = value |
|
|
195 |
|
|
|
196 |
# Check stride |
|
|
197 |
assert (('s' in options and len(options['s']) == 1) or |
|
|
198 |
(len(options['s']) == 2 and options['s'][0] == options['s'][1])) |
|
|
199 |
|
|
|
200 |
return BlockArgs( |
|
|
201 |
kernel_size=int(options['k']), |
|
|
202 |
num_repeat=int(options['r']), |
|
|
203 |
input_filters=int(options['i']), |
|
|
204 |
output_filters=int(options['o']), |
|
|
205 |
expand_ratio=int(options['e']), |
|
|
206 |
id_skip=('noskip' not in block_string), |
|
|
207 |
se_ratio=float(options['se']) if 'se' in options else None, |
|
|
208 |
stride=[int(options['s'][0])]) |
|
|
209 |
|
|
|
210 |
@staticmethod |
|
|
211 |
def _encode_block_string(block): |
|
|
212 |
"""Encodes a block to a string.""" |
|
|
213 |
args = [ |
|
|
214 |
'r%d' % block.num_repeat, |
|
|
215 |
'k%d' % block.kernel_size, |
|
|
216 |
's%d%d' % (block.strides[0], block.strides[1]), |
|
|
217 |
'e%s' % block.expand_ratio, |
|
|
218 |
'i%d' % block.input_filters, |
|
|
219 |
'o%d' % block.output_filters |
|
|
220 |
] |
|
|
221 |
if 0 < block.se_ratio <= 1: |
|
|
222 |
args.append('se%s' % block.se_ratio) |
|
|
223 |
if block.id_skip is False: |
|
|
224 |
args.append('noskip') |
|
|
225 |
return '_'.join(args) |
|
|
226 |
|
|
|
227 |
@staticmethod |
|
|
228 |
def decode(string_list): |
|
|
229 |
""" |
|
|
230 |
Decodes a list of string notations to specify blocks inside the network. |
|
|
231 |
|
|
|
232 |
:param string_list: a list of strings, each string is a notation of block |
|
|
233 |
:return: a list of BlockArgs namedtuples of block args |
|
|
234 |
""" |
|
|
235 |
assert isinstance(string_list, list) |
|
|
236 |
blocks_args = [] |
|
|
237 |
for block_string in string_list: |
|
|
238 |
blocks_args.append(BlockDecoder._decode_block_string(block_string)) |
|
|
239 |
return blocks_args |
|
|
240 |
|
|
|
241 |
@staticmethod |
|
|
242 |
def encode(blocks_args): |
|
|
243 |
""" |
|
|
244 |
Encodes a list of BlockArgs to a list of strings. |
|
|
245 |
|
|
|
246 |
:param blocks_args: a list of BlockArgs namedtuples of block args |
|
|
247 |
:return: a list of strings, each string is a notation of block |
|
|
248 |
""" |
|
|
249 |
block_strings = [] |
|
|
250 |
for block in blocks_args: |
|
|
251 |
block_strings.append(BlockDecoder._encode_block_string(block)) |
|
|
252 |
return block_strings |
|
|
253 |
|
|
|
254 |
|
|
|
255 |
def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, |
|
|
256 |
drop_connect_rate=0.2, image_size=None, num_classes=1000): |
|
|
257 |
""" Creates a efficientnet model. """ |
|
|
258 |
|
|
|
259 |
blocks_args = [ |
|
|
260 |
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', |
|
|
261 |
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', |
|
|
262 |
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', |
|
|
263 |
'r1_k3_s11_e6_i192_o320_se0.25', |
|
|
264 |
] |
|
|
265 |
blocks_args = BlockDecoder.decode(blocks_args) |
|
|
266 |
|
|
|
267 |
global_params = GlobalParams( |
|
|
268 |
batch_norm_momentum=0.99, |
|
|
269 |
batch_norm_epsilon=1e-3, |
|
|
270 |
dropout_rate=dropout_rate, |
|
|
271 |
drop_connect_rate=drop_connect_rate, |
|
|
272 |
# data_format='channels_last', # removed, this is always true in PyTorch |
|
|
273 |
num_classes=num_classes, |
|
|
274 |
width_coefficient=width_coefficient, |
|
|
275 |
depth_coefficient=depth_coefficient, |
|
|
276 |
depth_divisor=8, |
|
|
277 |
min_depth=None, |
|
|
278 |
image_size=image_size, |
|
|
279 |
) |
|
|
280 |
|
|
|
281 |
return blocks_args, global_params |
|
|
282 |
|
|
|
283 |
|
|
|
284 |
def get_model_params(model_name, override_params): |
|
|
285 |
""" Get the block args and global params for a given model """ |
|
|
286 |
if model_name.startswith('efficientnet'): |
|
|
287 |
w, d, s, p = efficientnet_params(model_name) |
|
|
288 |
# note: all models have drop connect rate = 0.2 |
|
|
289 |
blocks_args, global_params = efficientnet( |
|
|
290 |
width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) |
|
|
291 |
else: |
|
|
292 |
raise NotImplementedError('model name is not pre-defined: %s' % model_name) |
|
|
293 |
if override_params: |
|
|
294 |
# ValueError will be raised here if override_params has fields not included in global_params. |
|
|
295 |
global_params = global_params._replace(**override_params) |
|
|
296 |
return blocks_args, global_params |
|
|
297 |
|
|
|
298 |
|
|
|
299 |
url_map = { |
|
|
300 |
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', |
|
|
301 |
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', |
|
|
302 |
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', |
|
|
303 |
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', |
|
|
304 |
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', |
|
|
305 |
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', |
|
|
306 |
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', |
|
|
307 |
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', |
|
|
308 |
} |
|
|
309 |
|
|
|
310 |
|
|
|
311 |
url_map_advprop = { |
|
|
312 |
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', |
|
|
313 |
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', |
|
|
314 |
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', |
|
|
315 |
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', |
|
|
316 |
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', |
|
|
317 |
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', |
|
|
318 |
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', |
|
|
319 |
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', |
|
|
320 |
'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', |
|
|
321 |
} |
|
|
322 |
|
|
|
323 |
|
|
|
324 |
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): |
|
|
325 |
""" Loads pretrained weights, and downloads if loading for the first time. """ |
|
|
326 |
# AutoAugment or Advprop (different preprocessing) |
|
|
327 |
url_map_ = url_map_advprop if advprop else url_map |
|
|
328 |
state_dict = model_zoo.load_url(url_map_[model_name]) |
|
|
329 |
if load_fc: |
|
|
330 |
model.load_state_dict(state_dict) |
|
|
331 |
else: |
|
|
332 |
state_dict.pop('_fc.weight') |
|
|
333 |
state_dict.pop('_fc.bias') |
|
|
334 |
res = model.load_state_dict(state_dict, strict=False) |
|
|
335 |
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' |
|
|
336 |
print('Loaded pretrained weights for {}'.format(model_name)) |