|
a |
|
b/EfficientNet_2d/EfficientNet_2d.py |
|
|
1 |
import torch |
|
|
2 |
from torch import nn |
|
|
3 |
from torch.nn import functional as F |
|
|
4 |
from EfficientNet_2d.utils import ( |
|
|
5 |
round_filters, |
|
|
6 |
round_repeats, |
|
|
7 |
drop_connect, |
|
|
8 |
get_same_padding_conv2d, |
|
|
9 |
get_model_params, |
|
|
10 |
efficientnet_params, |
|
|
11 |
load_pretrained_weights, |
|
|
12 |
Swish, |
|
|
13 |
MemoryEfficientSwish, |
|
|
14 |
) |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
class MBConvBlock(nn.Module):
    """
    Mobile inverted residual bottleneck block (MBConv).

    The block chains an optional 1x1 expansion convolution, a depthwise
    convolution, an optional squeeze-and-excitation gate and a 1x1 projection
    convolution. A residual connection is added when the block configuration
    allows it (see ``forward``).

    Args:
        block_args (namedtuple): Per-block settings. The fields read here are
            ``input_filters``, ``output_filters``, ``expand_ratio``,
            ``kernel_size``, ``stride``, ``se_ratio`` and ``id_skip``.
        global_params (namedtuple): Settings shared by all blocks. The fields
            read here are ``batch_norm_momentum``, ``batch_norm_epsilon`` and
            ``image_size``.

    Attributes:
        has_se (bool): Whether the block contains a squeeze-and-excitation gate
            (``se_ratio`` is set and lies in ``(0, 1]``).
        id_skip: Copied from ``block_args.id_skip``; enables the residual path.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        # BatchNorm2d is given the complement of the configured momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip  # gates skip connection and drop connect

        # Convolution factory chosen by the helper from the configured image size.
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        def batch_norm(channels):
            # All normalisation layers of the block share momentum and epsilon.
            return nn.BatchNorm2d(num_features=channels, momentum=self._bn_mom, eps=self._bn_eps)

        in_ch = block_args.input_filters
        mid_ch = in_ch * block_args.expand_ratio  # width inside the bottleneck

        # Expansion phase (omitted entirely when the ratio is 1).
        if block_args.expand_ratio != 1:
            self._expand_conv = Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=1, bias=False)
            self._bn0 = batch_norm(mid_ch)

        # Depthwise phase: groups == channels makes the convolution depthwise.
        self._depthwise_conv = Conv2d(
            in_channels=mid_ch, out_channels=mid_ch, groups=mid_ch,
            kernel_size=block_args.kernel_size, stride=block_args.stride, bias=False)
        self._bn1 = batch_norm(mid_ch)

        # Squeeze-and-excitation phase; its width is derived from the block
        # input width, not from the expanded width.
        if self.has_se:
            squeezed_ch = max(1, int(in_ch * se_ratio))
            self._se_reduce = Conv2d(in_channels=mid_ch, out_channels=squeezed_ch, kernel_size=1)
            self._se_expand = Conv2d(in_channels=squeezed_ch, out_channels=mid_ch, kernel_size=1)

        # Output (projection) phase.
        out_ch = block_args.output_filters
        self._project_conv = Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=1, bias=False)
        self._bn2 = batch_norm(out_ch)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        Run the block on ``inputs``.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1);
            a falsy value disables drop connect
        :return: output of block
        """
        args = self._block_args

        # Expansion, then depthwise convolution.
        x = inputs
        if args.expand_ratio != 1:
            x = self._expand_conv(x)
            x = self._bn0(x)
            x = self._swish(x)
        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and excitation: a per-channel gate computed from the
        # globally pooled activations.
        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._se_reduce(gate)
            gate = self._swish(gate)
            gate = self._se_expand(gate)
            x = torch.sigmoid(gate) * x

        # Projection back to the output width (no activation afterwards).
        x = self._project_conv(x)
        x = self._bn2(x)

        # The residual is only added when skipping is enabled, the stride
        # compares equal to 1 and the channel count is unchanged.
        can_skip = self.id_skip and args.stride == 1 and args.input_filters == args.output_filters
        if not can_skip:
            return x
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        return x + inputs

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
|
|
99 |
|
|
|
100 |
|
|
|
101 |
class EfficientNet(nn.Module):
    """
    An EfficientNet classifier: stem convolution, a stack of MBConv blocks,
    a 1x1 head convolution, global average pooling, dropout and a linear layer.
    Most easily created with the ``from_name`` or ``from_pretrained`` methods.

    Args:
        blocks_args (list): A list of BlockArgs, one entry per stage of blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Convolution factory chosen by the helper from the configured image size.
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm settings shared by the stem and the head.
        norm_kwargs = dict(
            momentum=1 - global_params.batch_norm_momentum,
            eps=global_params.batch_norm_epsilon,
        )

        # Stem: 3-channel (RGB) input, stride-2 3x3 convolution.
        stem_width = round_filters(32, global_params)
        self._conv_stem = Conv2d(3, stem_width, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=stem_width, **norm_kwargs)

        # Blocks: each entry of blocks_args describes one stage.
        self._blocks = nn.ModuleList([])
        for stage_args in blocks_args:
            # Scale widths and repeat count with the global multipliers.
            stage_args = stage_args._replace(
                input_filters=round_filters(stage_args.input_filters, global_params),
                output_filters=round_filters(stage_args.output_filters, global_params),
                num_repeat=round_repeats(stage_args.num_repeat, global_params),
            )

            # The first block of a stage handles the stride and width change.
            self._blocks.append(MBConvBlock(stage_args, global_params))

            # The remaining blocks keep the width and use stride 1.
            extra_blocks = stage_args.num_repeat - 1
            if extra_blocks > 0:
                stage_args = stage_args._replace(input_filters=stage_args.output_filters, stride=1)
            for _ in range(extra_blocks):
                self._blocks.append(MBConvBlock(stage_args, global_params))

        # Head: widen the output of the final block.
        head_in = stage_args.output_filters
        head_out = round_filters(1280, global_params)
        self._conv_head = Conv2d(head_in, head_out, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=head_out, **norm_kwargs)

        # Classifier.
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(global_params.dropout_rate)
        self._fc = nn.Linear(head_out, global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
        for mbconv in self._blocks:
            mbconv.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """
        # Stem
        x = self._conv_stem(inputs)
        x = self._bn0(x)
        x = self._swish(x)

        # Blocks: the drop connect rate grows linearly with the block index.
        base_rate = self._global_params.drop_connect_rate
        total_blocks = len(self._blocks)
        for position, mbconv in enumerate(self._blocks):
            rate = base_rate
            if rate:
                rate = rate * (float(position) / total_blocks)
            x = mbconv(x, drop_connect_rate=rate)

        # Head
        x = self._conv_head(x)
        x = self._bn1(x)
        return self._swish(x)

    def forward(self, inputs):
        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
        batch = inputs.size(0)
        features = self.extract_features(inputs)
        # Pool to one value per channel and flatten to (batch, channels).
        pooled = self._avg_pooling(features).view(batch, -1)
        return self._fc(self._dropout(pooled))

    @classmethod
    def from_name(cls, model_name, override_params=None):
        """ Builds an untrained model for a validated model name. """
        cls._check_model_name_is_valid(model_name)
        stage_list, shared_params = get_model_params(model_name, override_params)
        return cls(stage_list, shared_params)

    @classmethod
    def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
        """
        Builds a model and loads pretrained weights into it.

        The classifier weights are requested from ``load_pretrained_weights``
        only when ``num_classes`` is 1000. When ``in_channels`` is not 3 the
        stem convolution is replaced by a newly constructed one after loading,
        so that layer does not carry pretrained weights.
        """
        net = cls.from_name(model_name, override_params={'num_classes': num_classes})
        load_pretrained_weights(net, model_name, load_fc=(num_classes == 1000), advprop=advprop)
        if in_channels == 3:
            return net
        Conv2d = get_same_padding_conv2d(image_size=net._global_params.image_size)
        stem_width = round_filters(32, net._global_params)
        net._conv_stem = Conv2d(in_channels, stem_width, kernel_size=3, stride=2, bias=False)
        return net

    @classmethod
    def get_image_size(cls, model_name):
        """ Returns the third value of efficientnet_params for a validated model name. """
        cls._check_model_name_is_valid(model_name)
        params = efficientnet_params(model_name)
        _, _, res, _ = params
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """ Validates model name. """
        valid_models = []
        for index in range(9):
            valid_models.append('efficientnet-b' + str(index))
        if model_name in valid_models:
            return
        raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
|
|
226 |
|
|
|
227 |
|
|
|
228 |
# get pretrained EfficientNet for k-classes classification |
|
|
229 |
def get_pretrained_EfficientNet(num_classes):
    """
    Return a pretrained 'efficientnet-b0' whose final linear layer is replaced
    by a newly initialised one with ``num_classes`` outputs.

    Args:
        num_classes (int): number of outputs of the new classifier layer

    Returns:
        EfficientNet: the model with the swapped ``_fc`` layer
    """
    backbone = EfficientNet.from_pretrained('efficientnet-b0')
    # Keep the classifier input width, change only the number of outputs.
    backbone._fc = nn.Linear(backbone._fc.in_features, num_classes)
    return backbone
|
|
234 |
|
|
|
235 |
|
|
|
236 |
class DAR_Effi(nn.Module):
    """
    Three parallel EfficientNet branches with feature fusion into the main one.

    The model holds three structurally identical branches that all receive the
    same input: the main branch (attributes without suffix, called "Prd-Net"
    in the original comments), a ``_cf`` branch ("CF-Net") and an ``_lr``
    branch ("LR-Net"). From block index ``att_start`` onwards the main branch
    features are re-weighted after every block using the features of the other
    two branches (see ``attention``). Each branch has its own head and
    classifier, and ``forward`` returns all three outputs.

    Args:
        blocks_args (list): A list of BlockArgs, one entry per stage of blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks
        in_channels (int): number of channels of the input tensor
        att_start (int): index of the first block after which ``attention``
            is applied to the main branch
    """

    def __init__(self, blocks_args=None, global_params=None, in_channels=3, att_start=11):
        super(DAR_Effi, self).__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        self.att_start = att_start  # first block index that uses attention

        # Convolution factory chosen by the helper from the configured image size.
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters shared by every stem and head.
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        def batch_norm(channels):
            return nn.BatchNorm2d(num_features=channels, momentum=bn_mom, eps=bn_eps)

        # NOTE: submodules are created in the same order as before (main, cf,
        # lr within each section) so state_dict ordering and the order of
        # random initialisation are unchanged.

        # Stems
        stem_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = batch_norm(stem_channels)
        self._conv_stem_cf = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0_cf = batch_norm(stem_channels)
        self._conv_stem_lr = Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, bias=False)
        self._bn0_lr = batch_norm(stem_channels)

        # Block stacks: one independent copy per branch.
        self._blocks = self._build_blocks()
        self._blocks_cf = self._build_blocks()
        self._blocks_lr = self._build_blocks()

        # Heads: the input width is the scaled output width of the last stage,
        # i.e. what the final block of every branch produces.
        head_in = round_filters(self._blocks_args[-1].output_filters, self._global_params)
        head_out = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(head_in, head_out, kernel_size=1, bias=False)
        self._bn1 = batch_norm(head_out)
        self._conv_head_cf = Conv2d(head_in, head_out, kernel_size=1, bias=False)
        self._bn1_cf = batch_norm(head_out)
        self._conv_head_lr = Conv2d(head_in, head_out, kernel_size=1, bias=False)
        self._bn1_lr = batch_norm(head_out)

        # Classifiers
        dropout_rate = self._global_params.dropout_rate
        num_classes = self._global_params.num_classes

        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(dropout_rate)
        self._fc = nn.Linear(head_out, num_classes)
        self._swish = MemoryEfficientSwish()

        self._avg_pooling_cf = nn.AdaptiveAvgPool2d(1)
        self._dropout_cf = nn.Dropout(dropout_rate)
        self._fc_cf = nn.Linear(head_out, num_classes)
        self._swish_cf = MemoryEfficientSwish()

        self._avg_pooling_lr = nn.AdaptiveAvgPool2d(1)
        self._dropout_lr = nn.Dropout(dropout_rate)
        self._fc_lr = nn.Linear(head_out, num_classes)
        self._swish_lr = MemoryEfficientSwish()

    def _build_blocks(self):
        """Create one fresh stack of MBConv blocks described by ``self._blocks_args``."""
        blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Scale widths and repeat count with the global multipliers.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block of a stage handles the stride and width change.
            blocks.append(MBConvBlock(block_args, self._global_params))

            # The remaining blocks keep the width and use stride 1.
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                blocks.append(MBConvBlock(block_args, self._global_params))
        return blocks

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        activation = MemoryEfficientSwish if memory_efficient else Swish
        self._swish = activation()
        self._swish_cf = activation()
        self._swish_lr = activation()
        for branch in (self._blocks, self._blocks_cf, self._blocks_lr):
            for block in branch:
                block.set_swish(memory_efficient)

    def attention(self, f_prd, f_cf, f_lr):
        """
        Re-weight the main-branch features using the two auxiliary branches.

        Returns ``f_prd + w_cf * f_prd + w_lr * f_prd`` where
        ``w_cf = 1 - sigmoid(f_cf)`` and
        ``w_lr = 1 - |sigmoid(f_prd) - sigmoid(f_lr)|``.

        :param f_prd: features of the main branch
        :param f_cf: features of the ``_cf`` branch
        :param f_lr: features of the ``_lr`` branch
        :return: fused features for the main branch
        """
        # Weight is large where the cf branch responds weakly.
        w_cf = 1 - torch.sigmoid(f_cf)
        add_cf = w_cf * f_prd

        # Weight is large where main and lr branches agree.
        w_lr = 1 - torch.abs(torch.sigmoid(f_prd) - torch.sigmoid(f_lr))
        add_lr = w_lr * f_prd

        return f_prd + add_cf + add_lr

    def extract_features(self, inputs):
        """ Returns the output of the final convolution layer of each branch as (main, cf, lr). """
        # Stems: every branch sees the same input.
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        x_cf = self._swish_cf(self._bn0_cf(self._conv_stem_cf(inputs)))
        x_lr = self._swish_lr(self._bn0_lr(self._conv_stem_lr(inputs)))

        # Blocks: the branches advance in lockstep; the drop connect rate
        # grows linearly with the block index.
        base_rate = self._global_params.drop_connect_rate
        num_blocks = len(self._blocks)
        for idx, (block, block_cf, block_lr) in enumerate(
                zip(self._blocks, self._blocks_cf, self._blocks_lr)):
            drop_connect_rate = base_rate
            if drop_connect_rate:
                drop_connect_rate = drop_connect_rate * (float(idx) / num_blocks)

            x = block(x, drop_connect_rate=drop_connect_rate)
            x_cf = block_cf(x_cf, drop_connect_rate=drop_connect_rate)
            x_lr = block_lr(x_lr, drop_connect_rate=drop_connect_rate)

            # Only the main branch is modified by the fusion.
            if idx >= self.att_start:
                x = self.attention(x, x_cf, x_lr)

        # Heads
        x = self._swish(self._bn1(self._conv_head(x)))
        x_cf = self._swish_cf(self._bn1_cf(self._conv_head_cf(x_cf)))
        x_lr = self._swish_lr(self._bn1_lr(self._conv_head_lr(x_lr)))

        return x, x_cf, x_lr

    def forward(self, inputs):
        """ Extracts features of all branches, classifies each and returns the outputs as (main, cf, lr). """
        bs = inputs.size(0)
        x, x_cf, x_lr = self.extract_features(inputs)

        def classify(features, pool, dropout, fc):
            # Pool to one value per channel, flatten to (batch, channels), classify.
            return fc(dropout(pool(features).view(bs, -1)))

        x = classify(x, self._avg_pooling, self._dropout, self._fc)
        x_cf = classify(x_cf, self._avg_pooling_cf, self._dropout_cf, self._fc_cf)
        x_lr = classify(x_lr, self._avg_pooling_lr, self._dropout_lr, self._fc_lr)

        return x, x_cf, x_lr

    @classmethod
    def from_name(cls, model_name, override_params=None, in_channels=3, att_start=11):
        """ Builds an untrained model for a validated model name. """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(blocks_args, global_params, in_channels, att_start)

    @classmethod
    def get_image_size(cls, model_name):
        """ Returns the third value of efficientnet_params for a validated model name. """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """ Validates model name. """
        valid_models = ['efficientnet-b' + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
|
|
437 |
|
|
|
438 |
|
|
|
439 |
def get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes, model_name='efficientnet-b0'):
    """
    Build a DAR_Effi model and fill its three branches from three parameter sets.

    Every key of ``prd_params`` is copied unchanged into the main branch. The
    value stored under the same key in ``cf_params`` / ``lr_params`` is copied
    to the key whose first dotted component carries a ``_cf`` / ``_lr``
    suffix, e.g. ``'_conv_stem.weight'`` -> ``'_conv_stem_cf.weight'``.

    Args:
        prd_params: mapping from parameter name to tensor for the main branch
            (presumably the state dict of a single-branch network with
            ``num_classes`` outputs - confirm against the caller)
        cf_params: mapping with the same keys, used for the ``_cf`` branch
        lr_params: mapping with the same keys, used for the ``_lr`` branch
        num_classes (int): number of outputs of each of the three classifiers
        model_name (str): architecture name passed to ``DAR_Effi.from_name``

    Returns:
        DAR_Effi: the model with the merged parameters loaded

    Raises:
        KeyError: if ``cf_params`` or ``lr_params`` lacks a key of ``prd_params``
    """
    dar_model = DAR_Effi.from_name(model_name)

    # Resize all three classifiers before taking the state dict so their
    # shapes match the incoming parameters.
    fc_features = dar_model._fc.in_features
    dar_model._fc = nn.Linear(fc_features, num_classes)
    dar_model._fc_cf = nn.Linear(fc_features, num_classes)
    dar_model._fc_lr = nn.Linear(fc_features, num_classes)
    dar_params = dar_model.state_dict()

    for key, prd_value in prd_params.items():
        # partition keeps a key without a dot intact (module == key, rest
        # empty); slicing at str.find()'s -1 would cut off its last character.
        module, dot, rest = key.partition('.')

        dar_params[key] = prd_value
        dar_params[module + '_cf' + dot + rest] = cf_params[key]
        dar_params[module + '_lr' + dot + rest] = lr_params[key]

    dar_model.load_state_dict(dar_params)
    return dar_model