Switch to side-by-side view

--- a
+++ b/BioSeqNet/resnest/torch/resnet.py
@@ -0,0 +1,305 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## Email: zhanghang0704@gmail.com
+## Copyright (c) 2020
+##
+## LICENSE file in the root directory of this source tree 
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+"""ResNet variants"""
+import math
+import torch
+import torch.nn as nn
+
+from .splat import SplAtConv1d
+
+__all__ = ['ResNet', 'Bottleneck']
+
+class DropBlock2D(object):
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
+
+class GlobalAvgPool1d(nn.Module):
+    def __init__(self):
+        """Global average pooling over the input's spatial dimensions"""
+        super(GlobalAvgPool1d, self).__init__()
+
+    def forward(self, inputs):
+        return nn.functional.adaptive_avg_pool1d(inputs, 1).view(inputs.size(0), -1)
+
+class Bottleneck(nn.Module):
+    """ResNet Bottleneck
+    """
+    # pylint: disable=unused-argument
+    expansion = 4
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 radix=1, cardinality=1, bottleneck_width=64,
+                 avd=False, avd_first=False, dilation=1, is_first=False,
+                 rectified_conv=False, rectify_avg=False,
+                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
+        super(Bottleneck, self).__init__()
+        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
+        self.conv1 = nn.Conv1d(inplanes, group_width, kernel_size=1, bias=False)
+        self.bn1 = norm_layer(group_width)
+        self.dropblock_prob = dropblock_prob
+        self.radix = radix
+        self.avd = avd and (stride > 1 or is_first)
+        self.avd_first = avd_first
+
+        if self.avd:
+            self.avd_layer = nn.AvgPool1d(3, stride, padding=1)
+            stride = 1
+
+        if dropblock_prob > 0.0:
+            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
+            if radix == 1:
+                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
+            self.dropblock3 = DropBlock2D(dropblock_prob, 3)
+
+        if radix >= 1:
+            self.conv2 = SplAtConv1d(
+                group_width, group_width, kernel_size=3,
+                stride=stride, padding=dilation,
+                dilation=dilation, groups=cardinality, bias=False,
+                radix=radix, rectify=rectified_conv,
+                rectify_avg=rectify_avg,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+        elif rectified_conv:
+            from rfconv import RFConv1d
+            self.conv2 = RFConv1d(
+                group_width, group_width, kernel_size=3, stride=stride,
+                padding=dilation, dilation=dilation,
+                groups=cardinality, bias=False,
+                average_mode=rectify_avg)
+            self.bn2 = norm_layer(group_width)
+        else:
+            self.conv2 = nn.Conv1d(
+                group_width, group_width, kernel_size=3, stride=stride,
+                padding=dilation, dilation=dilation,
+                groups=cardinality, bias=False)
+            self.bn2 = norm_layer(group_width)
+
+        self.conv3 = nn.Conv1d(
+            group_width, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = norm_layer(planes*4)
+
+        if last_gamma:
+            from torch.nn.init import zeros_
+            zeros_(self.bn3.weight)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.dilation = dilation
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock1(out)
+        out = self.relu(out)
+
+        if self.avd and self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv2(out)
+        if self.radix == 0:
+            out = self.bn2(out)
+            if self.dropblock_prob > 0.0:
+                out = self.dropblock2(out)
+            out = self.relu(out)
+
+        if self.avd and not self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Module):
+    """ResNet Variants
+
+    Parameters
+    ----------
+    block : Block
+        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
+    layers : list of int
+        Numbers of layers in each block
+    classes : int, default 1000
+        Number of classification classes.
+    dilated : bool, default False
+        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
+        typically used in Semantic Segmentation.
+    norm_layer : object
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        for Synchronized Cross-GPU BachNormalization).
+
+    Reference:
+
+        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+
+        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
+    """
+    # pylint: disable=unused-variable
+    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
+                 num_classes=1000, dilated=False, dilation=1,
+                 deep_stem=False, stem_width=64, avg_down=False,
+                 rectified_conv=False, rectify_avg=False,
+                 avd=False, avd_first=False,
+                 final_drop=0.0, dropblock_prob=0,
+                 last_gamma=False, norm_layer=nn.BatchNorm1d, num_channels=4):
+        self.cardinality = groups
+        self.bottleneck_width = bottleneck_width
+        # ResNet-D params
+        self.inplanes = stem_width*2 if deep_stem else 64
+        self.avg_down = avg_down
+        self.last_gamma = last_gamma
+        # ResNeSt params
+        self.radix = radix
+        self.avd = avd
+        self.avd_first = avd_first
+
+        super(ResNet, self).__init__()
+        self.rectified_conv = rectified_conv
+        self.rectify_avg = rectify_avg
+        if rectified_conv:
+            from rfconv import RFConv1d
+            conv_layer = RFConv1d
+        else:
+            conv_layer = nn.Conv1d
+        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
+        if deep_stem:
+            self.conv1 = nn.Sequential(
+                conv_layer(num_channels, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
+                norm_layer(stem_width),
+                nn.ReLU(inplace=True),
+                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+                norm_layer(stem_width),
+                nn.ReLU(inplace=True),
+                conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+            )
+        else:
+            self.conv1 = conv_layer(num_channels, 64, kernel_size=7, stride=2, padding=3,
+                                   bias=False, **conv_kwargs)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
+        if dilated or dilation == 4:
+            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
+                                           dilation=2, norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
+                                           dilation=4, norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+        elif dilation==2:
+            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                           dilation=1, norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
+                                           dilation=2, norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+        else:
+            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                           norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                           norm_layer=norm_layer,
+                                           dropblock_prob=dropblock_prob)
+        self.avgpool = GlobalAvgPool1d()
+        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        #for m in self.modules():
+        #    if isinstance(m, nn.Conv1d):
+        #        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        #        m.weight.data.normal_(0, math.sqrt(2. / n))
+        #    elif isinstance(m, norm_layer):
+        #        m.weight.data.fill_(1)
+        #        m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
+                    dropblock_prob=0.0, is_first=True):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            down_layers = []
+            if self.avg_down:
+                if dilation == 1:
+                    down_layers.append(nn.AvgPool1d(kernel_size=stride, stride=stride,
+                                                    ceil_mode=True, count_include_pad=False))
+                else:
+                    down_layers.append(nn.AvgPool1d(kernel_size=1, stride=1,
+                                                    ceil_mode=True, count_include_pad=False))
+                down_layers.append(nn.Conv1d(self.inplanes, planes * block.expansion,
+                                             kernel_size=1, stride=1, bias=False))
+            else:
+                down_layers.append(nn.Conv1d(self.inplanes, planes * block.expansion,
+                                             kernel_size=1, stride=stride, bias=False))
+            down_layers.append(norm_layer(planes * block.expansion))
+            downsample = nn.Sequential(*down_layers)
+
+        layers = []
+        if dilation == 1 or dilation == 2:
+            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+        elif dilation == 4:
+            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+        else:
+            raise RuntimeError("=> unknown dilation size: {}".format(dilation))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=dilation, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        #x = x.view(x.size(0), -1)
+        x = torch.flatten(x, 1)
+        if self.drop:
+            x = self.drop(x)
+        x = self.fc(x)
+
+        return x