Diff of /networks.py [000000] .. [77dc1e]

Switch to unified view

a b/networks.py
1
import numpy as np
2
import torch
3
from torchvision import models
4
from efficientnet_pytorch import EfficientNet
5
import torch.nn.functional as F
6
import timm
7
8
9
def model_builder(architecture_name, output=6):
10
    if architecture_name.startswith("resnet"):
11
        net = eval("models." + architecture_name)(pretrained=True)
12
        net.fc = torch.nn.Linear(net.fc.in_features, output)
13
        return net
14
    elif architecture_name.startswith("efficientnet"):
15
        n = int(architecture_name[-1])
16
        net = EfficientNet.from_pretrained(f'efficientnet-b{n}')
17
        net._fc = torch.nn.Linear(net._fc.in_features, output)
18
        return net
19
    elif architecture_name.startswith("densenet"):
20
        net = eval("models." + architecture_name)(pretrained=True)
21
        net.classifier = torch.nn.Linear(net.classifier.in_features, output)
22
        return net
23
    elif architecture_name == "vgg19":
24
        net = models.vgg19_bn(pretrained=True)
25
        net.classifier[6] = torch.nn.Linear(net.classifier[6].in_features, output)
26
        return net
27
    elif architecture_name == "seresnext":
28
        net = timm.create_model('gluon_seresnext101_32x4d', pretrained=True)
29
        net.fc = torch.nn.Linear(net.fc.in_features, 6)
30
        return net
31
32
33
# https://github.com/pudae/kaggle-hpa/blob/master/losses/loss_factory.py
34
def binary_focal_loss(gamma=2, **_):
35
36
    def func(input, target):
37
        assert target.size() == input.size()
38
39
        max_val = (-input).clamp(min=0)
40
41
        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
42
        invprobs = F.logsigmoid(-input * (target * 2 - 1))
43
        loss = (invprobs * gamma).exp() * loss
44
        return loss.mean()
45
46
    return func
47
48
49
class Windowing(torch.nn.Module):
50
    def __init__(self, u=1, epsilon=1e-3, window_length=50, window_width=130, transform="sigmoid"):
51
        """
52
        Practical Window Setting Optimization for Medical Image Deep Learning https://arxiv.org/pdf/1812.00572.pdf
53
        :param u: Upper bound for image values, e.g. 255
54
        :param epsilon:
55
        :param window_length:
56
        :param window_width:
57
        """
58
        super(Windowing, self).__init__()
59
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
60
        self.u = u
61
62
        if transform == "sigmoid":
63
            weight = (2 / window_width) * np.log((u/epsilon) - 1)
64
            bias = (-2 * window_length / window_width) * np.log((u / epsilon) - 1)
65
            self.transform = self.sigmoid_transform
66
        else:  # Linear window
67
            weight = u / window_width
68
            bias = (-u / window_width) * (window_length - (window_width / 2))
69
            self.transform = self.linear_transform
70
71
        self.conv.weight = torch.nn.Parameter(weight * torch.ones_like(self.conv.weight))
72
        self.conv.bias = torch.nn.Parameter(bias * torch.ones_like(self.conv.bias))
73
74
    def linear_transform(self, x):
75
        return torch.relu(torch.max(x, torch.tensor(self.u)))
76
77
    def sigmoid_transform(self, x):
78
        return self.u * torch.sigmoid(x)
79
80
    def forward(self, img):
81
        return self.transform(self.conv(img))
82
83
84
class ResNetModel(torch.nn.Module):
85
    def __init__(self, step_train=False, output=6):
86
        super(ResNetModel, self).__init__()
87
        self.net = models.resnet50(pretrained=True)
88
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, output)
89
        self.blocks = ["layer1", "layer2", "layer3", "layer4"]
90
        self.frozen_blocks = 4
91
92
        # Gradually unfreeze layers throughout training
93
        if step_train:
94
            for name, param in self.net.named_parameters():
95
                param.requires_grad_(False)
96
            self.unfreeze_layers()
97
98
    def phase1_model(self):
99
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, 1)
100
101
    def phase2_model(self):
102
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, 5)
103
104
    def unfreeze_layers(self, lower_bound=0):
105
        for name, param in self.net.named_parameters():
106
            if self.frozen_blocks < 0:
107
                param.requires_grad_(True)
108
            elif name.split(".")[0] in ["fc"]:
109
                param.requires_grad_(True)
110
            elif name.split(".")[0] in self.blocks[self.frozen_blocks:]:
111
                param.requires_grad_(True)
112
113
        if self.frozen_blocks >= lower_bound:
114
            self.frozen_blocks -= 1
115
116
    def forward(self, x):
117
        return self.net(x)
118
119
120
class DenseNetModel(torch.nn.Module):
121
    def __init__(self, step_train=False, output=6):
122
        super(DenseNetModel, self).__init__()
123
        self.net = models.densenet169(pretrained=True)
124
        self.net.classifier = torch.nn.Linear(self.net.classifier.in_features, output)
125
        self.blocks = ["denseblock1", "denseblock2", "denseblock3", "denseblock4"]
126
        self.frozen_blocks = 4
127
128
        # Gradually unfreeze layers throughout training
129
        if step_train:
130
            for name, param in self.net.named_parameters():
131
                param.requires_grad_(False)
132
            self.unfreeze_layers()
133
134
    def phase1_model(self):
135
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, 1)
136
137
    def phase2_model(self):
138
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, 5)
139
140
    def unfreeze_layers(self, lower_bound=0):
141
        for name, param in self.net.named_parameters():
142
            if self.frozen_blocks < 0:
143
                param.requires_grad_(True)
144
            elif name.split(".")[0] in ["fc"]:
145
                param.requires_grad_(True)
146
            elif name.split(".")[0] in self.blocks[self.frozen_blocks:]:
147
                param.requires_grad_(True)
148
149
        if self.frozen_blocks >= lower_bound:
150
            self.frozen_blocks -= 1
151
152
    def forward(self, x):
153
        return self.net(x)
154
155
156
class EfficientNetModel(torch.nn.Module):
157
    """
158
    # Coefficients:   width,depth,res,dropout
159
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
160
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
161
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
162
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
163
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
164
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
165
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
166
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
167
    """
168
    def __init__(self, n=0, step_train=False, output=6):
169
        super(EfficientNetModel, self).__init__()
170
        self.net = EfficientNet.from_pretrained(f'efficientnet-b{n}')
171
        self.net._fc = torch.nn.Linear(self.net._fc.in_features, output)
172
173
        filters = [block._block_args.output_filters for block in self.net._blocks]
174
        self.freeze_points = (np.where(np.diff(filters) > 0)[0])  # 6 main block groups which can be frozen/unfrozen
175
        self.frozen_blocks = 6
176
177
        # Gradually unfreeze layers throughout training
178
        if step_train:
179
            for name, param in self.net.named_parameters():
180
                param.requires_grad_(False)
181
            self.unfreeze_layers()
182
183
    def phase1_model(self):
184
        self.net._fc = torch.nn.Linear(self.net._fc.in_features, 1)
185
186
    def phase2_model(self):
187
        self.net._fc = torch.nn.Linear(self.net._fc.in_features, 5)
188
189
    def unfreeze_layers(self, lower_bound=3):
190
        try:
191
            fp = self.freeze_points[self.frozen_blocks]
192
        except IndexError:
193
            fp = np.Inf
194
195
        for name, param in self.net.named_parameters():
196
            if name.split(".")[0] in ["_conv_head", "_bn1", "_fc"]:
197
                param.requires_grad_(True)
198
            elif name.split(".")[1].isnumeric():
199
                block_number = int(name.split(".")[1])
200
                if block_number > fp:
201
                    param.requires_grad_(True)
202
203
        if self.frozen_blocks >= lower_bound:
204
            self.frozen_blocks -= 1
205
            print("Trainable blocks:", 6 - self.frozen_blocks)
206
207
    def forward(self, x):
208
        return self.net(x)