Diff of /src/model.py [000000] .. [f45789]

Switch to unified view

a b/src/model.py
1
import torch
2
import torch.nn as nn
3
import torch.optim as optim
4
import numpy as np
5
import torchvision
6
from torchvision import datasets, models, transforms
7
import matplotlib.pyplot as plt
8
import time
9
import os
10
import copy
11
12
import torch.nn.functional as F
13
14
RESNET = ['resnet18','resnet34','resnet50','resnet101','resnet152']
15
16
def set_parameter_requires_grad(model, feature_extracting):
17
    if feature_extracting:
18
        for param in model.parameters():
19
            param.requires_grad = False
20
21
def initialize_model(conf):
22
    model_name = conf['model']['name']
23
    feature_extract = conf['model']['feature_extract']
24
    use_pretrained = conf['model']['use_pretrained']
25
    print_model = conf['model']['print_model']
26
    num_classes = len(conf['data']['classes'])
27
28
    if model_name in RESNET:
29
        model = getattr(models, model_name)(pretrained=use_pretrained)
30
        set_parameter_requires_grad(model, feature_extract)
31
        num_ftrs = model.fc.in_features
32
        model.fc = nn.Linear(num_ftrs, num_classes)
33
    elif model_name == 'efficientdet_d0':
34
        model = EfficientClassification(num_classes)
35
    else:
36
        print("Invalid model name, exiting...")
37
        exit()
38
39
    if print_model: print(model)
40
    model.name = model_name
41
    return model
42
43
class EfficientClassification(nn.Module):
44
45
    def __init__(self, num_classes):
46
        super(EfficientClassification, self).__init__()
47
        from effdet import create_model
48
        self.effdet = create_model(model_name='efficientdet_d0')
49
        self.effdet.box_net = nn.Identity()
50
        self.effdet.class_net = nn.Identity()
51
        self.resnet = models.resnet18(pretrained=True)
52
53
        num_ftrs = self.resnet.fc.in_features
54
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)
55
56
        self.deconv0 = nn.ConvTranspose2d(in_channels=64,
57
                                          out_channels=16,
58
                                          kernel_size=19,
59
                                          stride=3,
60
                                          padding=1,
61
                                          dilation=2)
62
63
        self.deconv1 = nn.ConvTranspose2d(in_channels=64,
64
                                          out_channels=12,
65
                                          kernel_size=9,
66
                                          stride=7,
67
                                          padding=1,
68
                                          dilation=1)
69
70
        self.deconv2 = nn.ConvTranspose2d(in_channels=64,
71
                                          out_channels=8,
72
                                          kernel_size=24,
73
                                          stride=9,
74
                                          padding=2,
75
                                          dilation=4)
76
77
        self.deconv3 = nn.ConvTranspose2d(in_channels=64,
78
                                          out_channels=4,
79
                                          kernel_size=28,
80
                                          stride=9,
81
                                          padding=1,
82
                                          dilation=6)
83
84
        self.deconv4 = nn.ConvTranspose2d(in_channels=64,
85
                                          out_channels=2,
86
                                          kernel_size=30,
87
                                          stride=8,
88
                                          padding=2,
89
                                          dilation=7)
90
91
        self.conv0 = nn.Conv2d(in_channels=42,
92
                              out_channels=16,
93
                              kernel_size=5,
94
                              padding=2)
95
96
        self.conv1 = nn.Conv2d(in_channels=16,
97
                              out_channels=3,
98
                              kernel_size=3,
99
                              padding=1)
100
101
    def forward(self, x):
102
        # EffNet + BiFPN
103
        fpn_out, _ = self.effdet(x)
104
105
        # Convolution Transpose
106
        out0 = self.deconv0(fpn_out[0])
107
        out1 = self.deconv1(fpn_out[1])
108
        out2 = self.deconv2(fpn_out[2])
109
        out3 = self.deconv3(fpn_out[3])
110
        out4 = self.deconv4(fpn_out[4])
111
        deconv_out = torch.cat([out0,out1,out2,out3,out4], dim=1)
112
113
        # Convolution
114
        conv_out = self.conv1(self.conv0(deconv_out))
115
116
        # Resnet18
117
        out = self.resnet(conv_out)
118
        return out
119
120
class EfficientClassification2(nn.Module):
121
122
    def __init__(self, num_classes):
123
        super(EfficientClassification2, self).__init__()
124
        from effdet import create_model
125
        self.effdet = create_model(model_name='efficientdet_d0')
126
        self.effdet.box_net = nn.Identity()
127
        self.effdet.class_net = nn.Identity()
128
129
        # In features from FPN
130
        fc_in_features = [64 * i*i for i in [64,32,16,8,4]]
131
        mid = 64
132
        self.fc0 = nn.Linear(fc_in_features[0], mid)
133
        self.fc1 = nn.Linear(fc_in_features[1], mid)
134
        self.fc2 = nn.Linear(fc_in_features[2], mid)
135
        self.fc3 = nn.Linear(fc_in_features[3], mid)
136
        self.fc4 = nn.Linear(fc_in_features[4], mid)
137
        self.fc_out = nn.Linear(5 * mid, num_classes)
138
139
    def forward(self, x):
140
        fpn_out, _ = self.effdet(x)
141
        fpn_out = list(map(lambda t: torch.flatten(t, start_dim=1), fpn_out))
142
        out0 = self.fc0(fpn_out[0])
143
        out1 = self.fc1(fpn_out[1])
144
        out2 = self.fc2(fpn_out[2])
145
        out3 = self.fc3(fpn_out[3])
146
        out4 = self.fc4(fpn_out[4])
147
        fc_outs = torch.cat([out0,out1,out2,out3,out4], dim=1)
148
        out = self.fc_out(fc_outs)
149
        return out
150
151
if __name__ == '__main__':
152
153
    x = torch.randn(20, 3, 512, 512)
154
    model = EfficientClassification(num_classes=2)
155
    fpn_out = model(x)
156
    print('FIN')