--- /dev/null
+++ b/BioSeqNet/resnest/gluon/resnet.py
@@ -0,0 +1,339 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## Email: zhanghang0704@gmail.com
+## Copyright (c) 2020
+##
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+"""ResNet variants for 1D sequence data, implemented in Gluon."""
+# pylint: disable=arguments-differ,unused-argument,missing-docstring
+from __future__ import division
+
+import os
+import math
+from mxnet.context import cpu
+from mxnet.gluon.block import HybridBlock
+from mxnet.gluon import nn
+from mxnet.gluon.nn import BatchNorm
+
+from .dropblock import DropBlock
+from .splat import SplitAttentionConv
+
+__all__ = ['ResNet', 'Bottleneck']
+
+def _update_input_size(input_size, stride):
+    sh, sw = (stride, stride) if isinstance(stride, int) else stride
+    ih, iw = (input_size, input_size) if isinstance(input_size, int) else input_size
+    oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+    input_size = (oh, ow)
+    return input_size
+
+class Bottleneck(HybridBlock):
+    """ResNet Bottleneck
+    """
+    # pylint: disable=unused-argument
+    expansion = 4
+    def __init__(self, channels, cardinality=1, bottleneck_width=64, strides=1, dilation=1,
+                 downsample=None, previous_dilation=1, norm_layer=None,
+                 norm_kwargs=None, last_gamma=False,
+                 dropblock_prob=0, input_size=None, use_splat=False,
+                 radix=2, avd=False, avd_first=False, in_channels=None,
+                 split_drop_ratio=0, **kwargs):
+        super(Bottleneck, self).__init__()
+        group_width = int(channels * (bottleneck_width / 64.)) * cardinality
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        self.dropblock_prob = dropblock_prob
+        self.use_splat = use_splat
+        self.avd = avd and (strides > 1 or previous_dilation != dilation)
+        self.avd_first = avd_first
+        if self.dropblock_prob > 0:
+            self.dropblock1 = DropBlock(dropblock_prob, 3, group_width, *input_size)
+            if self.avd:
+                if avd_first:
+                    input_size = _update_input_size(input_size, strides)
+                self.dropblock2 = DropBlock(dropblock_prob, 3, group_width, *input_size)
+                if not avd_first:
+                    input_size = _update_input_size(input_size, strides)
+            else:
+                input_size = _update_input_size(input_size, strides)
+                self.dropblock2 = DropBlock(dropblock_prob, 3, group_width, *input_size)
+            self.dropblock3 = DropBlock(dropblock_prob, 3, channels*4, *input_size)
+        self.conv1 = nn.Conv1D(channels=group_width, kernel_size=1,
+                               use_bias=False, in_channels=in_channels)
+        self.bn1 = norm_layer(in_channels=group_width, **norm_kwargs)
+        self.relu1 = nn.Activation('relu')
+        if self.use_splat:
+            self.conv2 = SplitAttentionConv(channels=group_width, kernel_size=3, strides=1 if self.avd else strides,
+                                            padding=dilation, dilation=dilation, groups=cardinality, use_bias=False,
+                                            in_channels=group_width, norm_layer=norm_layer, norm_kwargs=norm_kwargs,
+                                            radix=radix, drop_ratio=split_drop_ratio, **kwargs)
+        else:
+            self.conv2 = nn.Conv1D(channels=group_width, kernel_size=3, strides=1 if self.avd else strides,
+                                   padding=dilation, dilation=dilation, groups=cardinality, use_bias=False,
+                                   in_channels=group_width, **kwargs)
+        self.bn2 = norm_layer(in_channels=group_width, **norm_kwargs)
+        self.relu2 = nn.Activation('relu')
+        self.conv3 = nn.Conv1D(channels=channels*4, kernel_size=1, use_bias=False, in_channels=group_width)
+        if not last_gamma:
+            self.bn3 = norm_layer(in_channels=channels*4, **norm_kwargs)
+        else:
+            self.bn3 = norm_layer(in_channels=channels*4, gamma_initializer='zeros',
+                                  **norm_kwargs)
+        if self.avd:
+            self.avd_layer = nn.AvgPool1D(3, strides, padding=1)
+        self.relu3 = nn.Activation('relu')
+        self.downsample = downsample
+        self.dilation = dilation
+        self.strides = strides
+
+    def hybrid_forward(self, F, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if self.dropblock_prob > 0:
+            out = self.dropblock1(out)
+        out = self.relu1(out)
+
+        if self.avd and self.avd_first:
+            out = self.avd_layer(out)
+
+        if self.use_splat:
+            out = self.conv2(out)
+            if self.dropblock_prob > 0:
+                out = self.dropblock2(out)
+        else:
+            out = self.conv2(out)
+            out = self.bn2(out)
+            if self.dropblock_prob > 0:
+                out = self.dropblock2(out)
+            out = self.relu2(out)
+
+        if self.avd and not self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        if self.dropblock_prob > 0:
+            out = self.dropblock3(out)
+
+        out = out + residual
+        out = self.relu3(out)
+
+        return out
+
+class ResNet(HybridBlock):
+    """ResNet Variants Definitions
+    Parameters
+    ----------
+    block : Block
+        Class for the residual block; this file provides :class:`Bottleneck`.
+    layers : list of int
+        Number of residual blocks in each of the four stages.
+    classes : int, default 1000
+        Number of classification classes.
+    dilated : bool, default False
+        Whether to apply the dilation strategy, which yields an overall stride-8 model
+        typically used for semantic segmentation.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`).
+        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    deep_stem : bool, default False
+        Whether to replace the kernel-size-7 conv1 with three kernel-size-3 convolution layers.
+    avg_down : bool, default False
+        Whether to use average pooling in the projection shortcut when downsampling between stages.
+    final_drop : float, default 0.0
+        Dropout ratio before the final classification layer.
+    use_global_stats : bool, default False
+        Whether to force BatchNorm to use global statistics instead of minibatch statistics;
+        typically set to True when finetuning from pretrained classification models.
+    Reference:
+        - He, Kaiming, et al. "Deep residual learning for image recognition."
+          Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
+ """ + # pylint: disable=unused-variable + def __init__(self, block, layers, cardinality=1, bottleneck_width=64, + classes=1000, dilated=False, dilation=1, norm_layer=BatchNorm, + norm_kwargs=None, last_gamma=False, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, use_global_stats=False, + name_prefix='', dropblock_prob=0, input_size=224, + use_splat=False, radix=2, avd=False, avd_first=False, split_drop_ratio=0, in_channels=3): + self.cardinality = cardinality + self.bottleneck_width = bottleneck_width + self.inplanes = stem_width*2 if deep_stem else 64 + self.radix = radix + self.split_drop_ratio = split_drop_ratio + self.avd_first = avd_first + super(ResNet, self).__init__(prefix=name_prefix) + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + if use_global_stats: + norm_kwargs['use_global_stats'] = True + self.norm_kwargs = norm_kwargs + with self.name_scope(): + if not deep_stem: + self.conv1 = nn.Conv1D(channels=64, kernel_size=7, strides=2, + padding=3, use_bias=False, in_channels=in_channels) + else: + self.conv1 = nn.HybridSequential(prefix='conv1') + self.conv1.add(nn.Conv1D(channels=stem_width, kernel_size=3, strides=2, + padding=1, use_bias=False, in_channels=in_channels)) + self.conv1.add(norm_layer(in_channels=stem_width, **norm_kwargs)) + self.conv1.add(nn.Activation('relu')) + self.conv1.add(nn.Conv1D(channels=stem_width, kernel_size=3, strides=1, + padding=1, use_bias=False, in_channels=stem_width)) + self.conv1.add(norm_layer(in_channels=stem_width, **norm_kwargs)) + self.conv1.add(nn.Activation('relu')) + self.conv1.add(nn.Conv1D(channels=stem_width*2, kernel_size=3, strides=1, + padding=1, use_bias=False, in_channels=stem_width)) + input_size = _update_input_size(input_size, 2) + self.bn1 = norm_layer(in_channels=64 if not deep_stem else stem_width*2, + **norm_kwargs) + self.relu = nn.Activation('relu') + self.maxpool = nn.MaxPool1D(pool_size=3, strides=2, padding=1) + input_size = _update_input_size(input_size, 2) + self.layer1 = self._make_layer(1, block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer, last_gamma=last_gamma, use_splat=use_splat, + avd=avd) + self.layer2 = self._make_layer(2, block, 128, layers[1], strides=2, avg_down=avg_down, + norm_layer=norm_layer, last_gamma=last_gamma, use_splat=use_splat, + avd=avd) + input_size = _update_input_size(input_size, 2) + if dilated or dilation==4: + self.layer3 = self._make_layer(3, block, 256, layers[2], strides=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer, + last_gamma=last_gamma, dropblock_prob=dropblock_prob, + input_size=input_size, use_splat=use_splat, avd=avd) + self.layer4 = self._make_layer(4, block, 512, layers[3], strides=1, dilation=4, pre_dilation=2, + avg_down=avg_down, norm_layer=norm_layer, + last_gamma=last_gamma, dropblock_prob=dropblock_prob, + input_size=input_size, use_splat=use_splat, avd=avd) + elif dilation==3: + # special + self.layer3 = self._make_layer(3, block, 256, layers[2], strides=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer, + last_gamma=last_gamma, dropblock_prob=dropblock_prob, + input_size=input_size, use_splat=use_splat, avd=avd) + self.layer4 = self._make_layer(4, block, 512, layers[3], strides=2, dilation=2, pre_dilation=2, + avg_down=avg_down, norm_layer=norm_layer, + last_gamma=last_gamma, dropblock_prob=dropblock_prob, + input_size=input_size, use_splat=use_splat, avd=avd) + elif dilation==2: + self.layer3 = self._make_layer(3, block, 256, layers[2], strides=2, + avg_down=avg_down, norm_layer=norm_layer, + 
+                                               last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                               input_size=input_size, use_splat=use_splat, avd=avd)
+                self.layer4 = self._make_layer(4, block, 512, layers[3], strides=1, dilation=2,
+                                               avg_down=avg_down, norm_layer=norm_layer,
+                                               last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                               input_size=input_size, use_splat=use_splat, avd=avd)
+            else:
+                self.layer3 = self._make_layer(3, block, 256, layers[2], strides=2,
+                                               avg_down=avg_down, norm_layer=norm_layer,
+                                               last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                               input_size=input_size, use_splat=use_splat, avd=avd)
+                input_size = _update_input_size(input_size, 2)
+                self.layer4 = self._make_layer(4, block, 512, layers[3], strides=2,
+                                               avg_down=avg_down, norm_layer=norm_layer,
+                                               last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                               input_size=input_size, use_splat=use_splat, avd=avd)
+                input_size = _update_input_size(input_size, 2)
+            self.avgpool = nn.GlobalAvgPool1D()
+            self.flat = nn.Flatten()
+            self.drop = None
+            if final_drop > 0.0:
+                self.drop = nn.Dropout(final_drop)
+            self.fc = nn.Dense(in_units=512 * block.expansion, units=classes)
+
+    def _make_layer(self, stage_index, block, planes, blocks, strides=1, dilation=1,
+                    pre_dilation=1, avg_down=False, norm_layer=None,
+                    last_gamma=False,
+                    dropblock_prob=0, input_size=224, use_splat=False, avd=False):
+        downsample = None
+        if strides != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.HybridSequential(prefix='down%d_' % stage_index)
+            with downsample.name_scope():
+                if avg_down:
+                    if pre_dilation == 1:
+                        downsample.add(nn.AvgPool1D(pool_size=strides, strides=strides,
+                                                    ceil_mode=True, count_include_pad=False))
+                    elif strides == 1:
+                        downsample.add(nn.AvgPool1D(pool_size=1, strides=1,
+                                                    ceil_mode=True, count_include_pad=False))
+                    else:
+                        downsample.add(nn.AvgPool1D(pool_size=pre_dilation*strides, strides=strides, padding=1,
+                                                    ceil_mode=True, count_include_pad=False))
+                    downsample.add(nn.Conv1D(channels=planes * block.expansion, kernel_size=1,
+                                             strides=1, use_bias=False, in_channels=self.inplanes))
+                    downsample.add(norm_layer(in_channels=planes * block.expansion,
+                                              **self.norm_kwargs))
+                else:
+                    downsample.add(nn.Conv1D(channels=planes * block.expansion,
+                                             kernel_size=1, strides=strides, use_bias=False,
+                                             in_channels=self.inplanes))
+                    downsample.add(norm_layer(in_channels=planes * block.expansion,
+                                              **self.norm_kwargs))
+
+        layers = nn.HybridSequential(prefix='layers%d_' % stage_index)
+        with layers.name_scope():
+            if dilation in (1, 2):
+                layers.add(block(planes, cardinality=self.cardinality,
+                                 bottleneck_width=self.bottleneck_width,
+                                 strides=strides, dilation=pre_dilation,
+                                 downsample=downsample, previous_dilation=dilation,
+                                 norm_layer=norm_layer, norm_kwargs=self.norm_kwargs,
+                                 last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                 input_size=input_size, use_splat=use_splat, avd=avd, avd_first=self.avd_first,
+                                 radix=self.radix, in_channels=self.inplanes,
+                                 split_drop_ratio=self.split_drop_ratio))
+            elif dilation == 4:
+                layers.add(block(planes, cardinality=self.cardinality,
+                                 bottleneck_width=self.bottleneck_width,
+                                 strides=strides, dilation=pre_dilation,
+                                 downsample=downsample, previous_dilation=dilation,
+                                 norm_layer=norm_layer, norm_kwargs=self.norm_kwargs,
+                                 last_gamma=last_gamma, dropblock_prob=dropblock_prob,
+                                 input_size=input_size, use_splat=use_splat, avd=avd, avd_first=self.avd_first,
+                                 radix=self.radix, in_channels=self.inplanes,
+                                 split_drop_ratio=self.split_drop_ratio))
+            else:
{}".format(dilation)) + + input_size = _update_input_size(input_size, strides) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.add(block(planes, cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, dilation=dilation, + previous_dilation=dilation, norm_layer=norm_layer, + norm_kwargs=self.norm_kwargs, last_gamma=last_gamma, + dropblock_prob=dropblock_prob, input_size=input_size, + use_splat=use_splat, avd=avd, avd_first=self.avd_first, + radix=self.radix, in_channels=self.inplanes, + split_drop_ratio=self.split_drop_ratio)) + + return layers + + def hybrid_forward(self, F, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = self.flat(x) + if self.drop is not None: + x = self.drop(x) + x = self.fc(x) + + return x