a b/HTNet/multi-modality/resnet.py
1
import torch
2
import torch.nn as nn
3
from collections import OrderedDict
4
#from .utils import load_state_dict_from_url
5
6
7
# Public names exported by ``from <this module> import *``.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'Bottleneck',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]
10
11
12
# Download locations of the torchvision ImageNet checkpoints, keyed by the
# ``arch`` name that ``_resnet`` uses to look them up.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    resnext50_32x4d='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    resnext101_32x8d='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    wide_resnet50_2='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    wide_resnet101_2='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
)
23
24
25
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, so with ``stride=1`` the spatial size
    of the input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups.
        dilation: kernel dilation (also used as the padding).
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
29
30
31
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
34
35
36
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (used by ResNet-18/34 here).

    Output channels equal ``planes * expansion`` (``expansion`` is 1).

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut
            when it must change shape to match the residual branch.
        groups: must be 1, otherwise ValueError is raised.
        base_width: must be 64, otherwise ValueError is raised.
        dilation: values > 1 raise NotImplementedError.
        norm_layer: normalization factory; ``nn.BatchNorm2d`` when None.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1 the first conv and the (caller-supplied) downsample
        # module both reduce spatial size, keeping the two paths aligned.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
74
75
76
class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The stride sits on the 3x3 convolution (``conv2``), as in torchvision's
    "ResNet V1.5"; the original paper ("Deep residual learning for image
    recognition", https://arxiv.org/abs/1512.03385) puts it on the first 1x1
    convolution. The V1.5 placement is reported to improve accuracy:
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch

    The inner width is ``int(planes * (base_width / 64)) * groups`` and the
    block outputs ``planes * expansion`` channels (``expansion`` is 4).

    Args:
        inplanes: channels of the block input.
        planes: base channel count of the block.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
        groups: groups of the 3x3 convolution (ResNeXt).
        base_width: width per group; scales the inner width.
        dilation: dilation (and padding) of the 3x3 convolution.
        norm_layer: normalization factory; ``nn.BatchNorm2d`` when None.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.)) * groups
        outer = planes * self.expansion
        # When stride != 1 the 3x3 conv and the downsample module both reduce
        # spatial size, keeping the two paths aligned.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = norm(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
123
124
125
class ResNet(nn.Module):
    """ResNet image trunk fused with an MLP branch over an auxiliary vector input.

    The image input goes through the stem (7x7 conv, norm, ReLU, max-pool), four
    residual stages, global average pooling and flattening, giving
    ``512 * block.expansion`` features per sample. The auxiliary input (called
    "antibody" in this code) has ``antibody_nums`` features and is mapped by
    ``antibody_net`` to the same width. The two feature vectors are summed and
    passed to ``fc``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        zero_init_residual: zero-initialise the last norm weight of every
            residual branch.
        groups: convolution groups forwarded to every block.
        width_per_group: base width forwarded to every block.
        replace_stride_with_dilation: three booleans, one per stage 2-4; True
            replaces that stage's stride-2 downsampling with dilation.
        norm_layer: normalization factory; ``nn.BatchNorm2d`` when None.
        antibody_nums: number of input features of the auxiliary branch.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, antibody_nums=6):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the stride-2 downsampling with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Width of the pooled image features. The auxiliary branch must produce
        # the same width because the two are summed before ``fc``.
        feature_dim = 512 * block.expansion
        self.fc = nn.Linear(feature_dim, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

        # Built after the init loops above, so these layers keep PyTorch's
        # default initialisation.
        self.antibody_net = nn.Sequential(OrderedDict([
            ('Ab_fc0'  , nn.Linear(antibody_nums, 1024, bias=True)),
            ('Ab_norm0', nn.GroupNorm(1, 1024)),
            ('Ab_relu0', nn.ReLU(inplace=True)),
            # Previously hard-coded to 2048, which only matches Bottleneck models.
            # With BasicBlock (resnet18/34) the pooled features are 512 wide, and
            # the sum in ``_forward_impl`` failed with a shape mismatch.
            ('Ab_fc1'  , nn.Linear(1024, feature_dim, bias=True))
        ]))

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1-conv + norm ``downsample`` shortcut. When ``dilate`` is True the
        stride is folded into ``self.dilation`` instead and the stage uses
        stride 1. Updates ``self.inplanes`` and possibly ``self.dilation``
        for later stages.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block runs at the previous stage's dilation; the rest use
        # the (possibly increased) current dilation.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x, x1):
        """Fuse image features of ``x`` with auxiliary features of ``x1``.

        ``x`` is a 3-channel image batch (``conv1`` has 3 input channels).
        ``x1`` presumably has shape (batch, antibody_nums); TODO confirm with
        callers. Returns ``fc(image_features + antibody_net(x1))``.
        """
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x1 = self.antibody_net(x1)

        x = self.fc(x + x1)

        return x

    def forward(self, x, x1):
        """Run the fused model; see ``_forward_impl``."""
        return self._forward_impl(x, x1)
231
232
233
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the torchvision checkpoint for ``arch``.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint to download.
        block: residual block class forwarded to ``ResNet``.
        layers: per-stage block counts forwarded to ``ResNet``.
        pretrained: if True, download the ``arch`` checkpoint and load it.
        progress: show a download progress bar.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: if the checkpoint contains keys the model lacks, or if
            parameters outside ``antibody_net`` are missing from it. A shape
            mismatch (e.g. ``num_classes != 1000``) also raises, from
            ``load_state_dict``.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        # The module-level import of this helper is commented out, which made
        # ``pretrained=True`` raise NameError; import it from torch.hub here.
        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # torchvision checkpoints do not contain ``antibody_net``, so a strict
        # load always failed. Load non-strictly, but tolerate only that branch
        # being absent.
        incompatible = model.load_state_dict(state_dict, strict=False)
        missing = [key for key in incompatible.missing_keys
                   if not key.startswith('antibody_net.')]
        unexpected = list(incompatible.unexpected_keys)
        if missing or unexpected:
            raise RuntimeError(
                'Pretrained weights for {} do not match the model: '
                'missing={}, unexpected={}'.format(arch, missing, unexpected))
    return model
240
241
242
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Uses ``BasicBlock`` with ``[2, 2, 2, 2]`` blocks per stage. Extra keyword
    arguments (e.g. ``num_classes``, ``antibody_nums``) are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnet18']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
252
253
254
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Uses ``BasicBlock`` with ``[3, 4, 6, 3]`` blocks per stage. Extra keyword
    arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnet34']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
264
265
266
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Uses ``Bottleneck`` with ``[3, 4, 6, 3]`` blocks per stage. Extra keyword
    arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnet50']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
276
277
278
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Uses ``Bottleneck`` with ``[3, 4, 23, 3]`` blocks per stage. Extra keyword
    arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnet101']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
288
289
290
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Uses ``Bottleneck`` with ``[3, 8, 36, 3]`` blocks per stage. Extra keyword
    arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnet152']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
300
301
302
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    ``Bottleneck`` with ``[3, 4, 6, 3]`` blocks per stage, 32 groups and a
    width of 4 per group. Any caller-supplied ``groups`` / ``width_per_group``
    are overridden; other keyword arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnext50_32x4d']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
314
315
316
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    ``Bottleneck`` with ``[3, 4, 23, 3]`` blocks per stage, 32 groups and a
    width of 8 per group. Any caller-supplied ``groups`` / ``width_per_group``
    are overridden; other keyword arguments are forwarded to ``ResNet``.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['resnext101_32x8d']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
328
329
330
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet-50 except that the bottleneck channel count
    is doubled in every block (``width_per_group=128``). The outer 1x1
    convolutions keep their channel counts: the last block of ResNet-50 is
    2048-512-2048, while in Wide ResNet-50-2 it is 2048-1024-2048.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['wide_resnet50_2']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
346
347
348
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet-101 except that the bottleneck channel count
    is doubled in every block (``width_per_group=128``). The outer 1x1
    convolutions keep their channel counts: the last block of ResNet-101 is
    2048-512-2048, while in Wide ResNet-101-2 it is 2048-1024-2048.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint listed under
            ``model_urls['wide_resnet101_2']``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)