Diff of /algorithms/classifiers.py [000000] .. [a18f15]

Switch to unified view

a b/algorithms/classifiers.py
1
import os, sys
2
import torch
3
import torchvision
4
from torch import nn
5
6
sys.path.append(os.getcwd())
7
from algorithms.arch.resnet import loadResnetBackbone
8
import utilities.runUtils as rutl
9
10
11
##================= CLassifier Wrapper =========================================
12
13
class ClassifierNet(nn.Module):
14
    def __init__(self, arch, fc_layer_sizes=[512,1000],
15
                    feature_dropout=0, classifier_dropout=0,
16
                    feature_freeze = False, feature_bnorm = False,
17
                    torch_pretrain=None):
18
        super().__init__()
19
        rutl.START_SEED(7)
20
21
        self.fc_layer_sizes = fc_layer_sizes
22
23
        # Feature Extractor
24
        self.backbone,self.feat_outsize = loadResnetBackbone(arch=arch,
25
                                            torch_pretrain=torch_pretrain,
26
                                            freeze=feature_freeze)
27
        fx_layers = []
28
        if feature_bnorm:
29
            fx_layers.append(nn.BatchNorm1d(self.feat_outsize, affine=False))
30
        fx_layers.append(nn.Dropout(p=feature_dropout))
31
32
        self.featx_proc = nn.Sequential(*fx_layers)
33
34
        # Classifier
35
        sizes = [self.feat_outsize] + list(self.fc_layer_sizes)
36
        layers = []
37
        for i in range(len(sizes) - 2):
38
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
39
            layers.append(nn.LayerNorm(sizes[i + 1]))
40
            layers.append(nn.ReLU(inplace=True))
41
            layers.append(nn.Dropout(p=classifier_dropout))
42
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
43
44
        self.classifier = nn.Sequential(*layers)
45
46
47
    def forward(self, x):
48
        x = self.backbone(x)
49
        x = self.featx_proc(x)
50
        out = self.classifier(x)
51
52
        return out
53
54
55
56
if __name__ == "__main__":
57
58
    from torchinfo import summary
59
60
    model = ClassifierNet(arch='efficientnet_b0', fc_layer_sizes=[64,8],
61
                    feature_dropout=0, classifier_dropout=0,
62
                    torch_pretrain=None)
63
    summary(model, (1, 3, 200, 200))
64
    print(model)