# BioSeqNet/resnest/torch/resnet.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""ResNet variants"""
import math
import torch
import torch.nn as nn

from .splat import SplAtConv1d

__all__ = ['ResNet', 'Bottleneck']

class DropBlock2D(object):
    """Placeholder kept from the upstream 2-D code; DropBlock is not
    implemented in this 1-D port, so dropblock_prob must stay 0."""
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

class GlobalAvgPool1d(nn.Module):
    """Global average pooling over the input's temporal dimension."""
    def __init__(self):
        super(GlobalAvgPool1d, self).__init__()

    def forward(self, inputs):
        # (N, C, L) -> (N, C)
        return nn.functional.adaptive_avg_pool1d(inputs, 1).view(inputs.size(0), -1)

class Bottleneck(nn.Module):
    """ResNet Bottleneck block (1-D variant; uses split-attention
    convolution for the 3-wide layer when radix >= 1).
    """
    # pylint: disable=unused-argument
    expansion = 4
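    # Worked channel math (illustrative note, not from the original source):
    # with the defaults planes=64, bottleneck_width=64, cardinality=1,
    # group_width = int(64 * (64 / 64.)) * 1 = 64, so the block maps
    # inplanes -> 64 (1x1 conv) -> 64 (3-wide conv, split-attention when
    # radix >= 1) -> 64 * expansion = 256 (1x1 conv).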
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv1d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            self.avd_layer = nn.AvgPool1d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        if radix >= 1:
            self.conv2 = SplAtConv1d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv1d  # optional external dependency
            self.conv2 = RFConv1d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv1d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv1d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes * 4)

        if last_gamma:
            from torch.nn.init import zeros_
            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 0:
            # radix >= 1 uses SplAtConv1d, which normalizes internally;
            # only the radix == 0 paths apply bn2/ReLU here.
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    """ResNet variants.

    Parameters
    ----------
    block : Block
        Class for the residual block (e.g. :class:`Bottleneck`).
    layers : list of int
        Number of blocks in each of the four stages.
    num_classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Apply the dilation strategy to a pretrained ResNet, yielding a
        stride-8 model, as typically used in semantic segmentation.
    norm_layer : object
        Normalization layer used in the backbone network
        (default: :class:`torch.nn.BatchNorm1d`; a synchronized cross-GPU
        BatchNorm can be substituted).

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." ICLR 2016.
    """
    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm1d, num_channels=4):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width*2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super(ResNet, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            from rfconv import RFConv1d  # optional external dependency
            conv_layer = RFConv1d
        else:
            conv_layer = nn.Conv1d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(num_channels, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(num_channels, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.avgpool = GlobalAvgPool1d()
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        #for m in self.modules():
        #    if isinstance(m, nn.Conv1d):
        #        n = m.kernel_size[0] * m.out_channels
        #        m.weight.data.normal_(0, math.sqrt(2. / n))
        #    elif isinstance(m, norm_layer):
        #        m.weight.data.fill_(1)
        #        m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                # ResNet-D shortcut: average-pool first, then project with a
                # stride-1 1x1 convolution.
                if dilation == 1:
                    down_layers.append(nn.AvgPool1d(kernel_size=stride, stride=stride,
                                                    ceil_mode=True, count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool1d(kernel_size=1, stride=1,
                                                    ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv1d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv1d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        if self.drop is not None:
            x = self.drop(x)
        x = self.fc(x)

        return x
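
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. It assumes
# the sibling `.splat` module providing SplAtConv1d is importable, so run it
# as a package module, e.g. `python -m BioSeqNet.resnest.torch.resnet`. The
# layer config [1, 1, 1, 1], radix=2, num_classes=10, and input length 1000
# are arbitrary illustrative choices; num_channels=4 matches this module's
# default (e.g. one-hot encoded DNA).
if __name__ == "__main__":
    model = ResNet(Bottleneck, [1, 1, 1, 1], radix=2,
                   num_classes=10, num_channels=4)
    dummy = torch.randn(2, 4, 1000)  # (batch, channels, sequence length)
    logits = model(dummy)
    print("output shape:", tuple(logits.shape))  # expected: (2, 10)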