Diff of /src/models/multimodals.py [000000] .. [95f789]

Switch to unified view

a b/src/models/multimodals.py
1
import torch
2
import torch.nn as nn
3
from cnn_finetune import make_model
4
from timm import create_model
5
6
7
def cnnfinetune_freeze(self):
8
    for param in self.parameters():
9
        param.requires_grad = False
10
11
    for param in self._classifier.parameters():
12
        param.requires_grad = True
13
14
15
def cnnfinetune_unfreeze(self):
16
    for param in self.parameters():
17
        param.requires_grad = True
18
19
20
def make_classifier(in_features, num_classes):
21
    return nn.Sequential(
22
        nn.Linear(in_features, 512),
23
        nn.Dropout(0.3),
24
        nn.Linear(512, num_classes),
25
    )
26
27
28
class MultiModals(nn.Module):
29
    def __init__(self, model_name, pretrained=True, num_classes=6, dropout_p=None):
30
        super(MultiModals, self).__init__()
31
        self.model = make_model(
32
            model_name=model_name,
33
            num_classes=num_classes,
34
            pretrained=pretrained,
35
            dropout_p=dropout_p,
36
            # classifier_factory=make_classifier
37
        )
38
39
        in_features = self.model._classifier.in_features
40
41
        self._classifier = nn.Sequential(
42
            nn.Linear(in_features + 8, 512),
43
            nn.Dropout(0.3),
44
            nn.Linear(512, num_classes),
45
        )
46
47
        setattr(self, 'freeze', cnnfinetune_freeze)
48
        setattr(self, 'unfreeze', cnnfinetune_unfreeze)
49
50
    def forward(self, images, meta):
51
        x = self.model._features(images)
52
        x = self.model.pool(x)
53
        x = x.view(x.size(0), -1)
54
        # import pdb
55
        # pdb.set_trace()
56
        # if isinstance(x, torch.HalfTensor):
57
        #     meta = meta.half()
58
        x = torch.cat([x, meta], dim=1)
59
        return self._classifier(x)