|
a |
|
b/src/models/multimodals.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from cnn_finetune import make_model |
|
|
4 |
from timm import create_model |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def cnnfinetune_freeze(self):
    """Freeze the backbone while keeping the custom classifier trainable.

    Meant to be bound to a model instance that owns a ``_classifier``
    sub-module: gradients are disabled for every parameter, then
    selectively re-enabled for the classifier head only.
    """
    # Turn everything off first, then re-enable just the head.
    for weight in self.parameters():
        weight.requires_grad = False
    for weight in self._classifier.parameters():
        weight.requires_grad = True
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def cnnfinetune_unfreeze(self):
    """Re-enable gradient updates for every parameter of the model."""
    for weight in self.parameters():
        weight.requires_grad = True
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def make_classifier(in_features, num_classes):
    """Build a small MLP head: Linear -> Dropout(0.3) -> Linear.

    Args:
        in_features: width of the incoming feature vector.
        num_classes: number of output logits.

    Returns:
        ``nn.Sequential`` mapping (batch, in_features) -> (batch, num_classes).
    """
    layers = [
        nn.Linear(in_features, 512),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    ]
    return nn.Sequential(*layers)
|
|
26 |
|
|
|
27 |
|
|
|
28 |
class MultiModals(nn.Module):
    """Multimodal classifier: CNN image features concatenated with metadata.

    A ``cnn_finetune`` backbone extracts and pools image features; these are
    concatenated with a per-sample metadata vector and classified by a small
    MLP head (the backbone's stock classifier is bypassed).

    Args:
        model_name: backbone architecture name passed to ``make_model``.
        pretrained: whether to load pretrained backbone weights.
        num_classes: number of output classes.
        dropout_p: dropout probability forwarded to ``make_model``.
        num_meta_features: width of the metadata vector expected by
            ``forward``. Defaults to 8, matching the original behavior.
    """

    def __init__(self, model_name, pretrained=True, num_classes=6,
                 dropout_p=None, num_meta_features=8):
        super().__init__()
        self.model = make_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=pretrained,
            dropout_p=dropout_p,
        )

        # Only the *input width* of the backbone's classifier is reused here;
        # forward() never calls that stock head.
        in_features = self.model._classifier.in_features

        self._classifier = nn.Sequential(
            nn.Linear(in_features + num_meta_features, 512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

        # BUG FIX: the original did setattr(self, 'freeze', cnnfinetune_freeze),
        # which stores the plain *function* on the instance — calling
        # model.freeze() then raises TypeError because `self` is never bound
        # (instance attributes do not go through the descriptor protocol).
        # Binding explicitly via __get__ produces real bound methods.
        self.freeze = cnnfinetune_freeze.__get__(self)
        self.unfreeze = cnnfinetune_unfreeze.__get__(self)

    def forward(self, images, meta):
        """Classify a batch of images plus their metadata.

        Args:
            images: image batch accepted by the backbone.
            meta: tensor of shape (batch, num_meta_features); its dtype must
                match the image features (e.g. cast to half under AMP).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.model._features(images)
        x = self.model.pool(x)
        x = x.view(x.size(0), -1)  # flatten pooled features to (batch, C)
        x = torch.cat([x, meta], dim=1)
        return self._classifier(x)