"""
Common routines for models in PyTorch.
"""
__all__ = ['conv1x1', 'conv3x3', 'depthwise_conv3x3', 'ConvBlock', 'conv1x1_block', 'conv3x3_block', 'conv7x7_block',
'dwconv3x3_block', 'PreConvBlock', 'pre_conv1x1_block', 'pre_conv3x3_block', 'ChannelShuffle',
'ChannelShuffle2', 'SEBlock', 'IBN', 'Identity', 'DualPathSequential', 'Concurrent', 'ParametricSequential',
'ParametricConcurrent', 'Hourglass', 'SesquialteralHourglass', 'MultiOutputSequential', 'Flatten']
import math
from inspect import isfunction
import torch
import torch.nn as nn
def conv1x1(in_channels,
out_channels,
stride=1,
groups=1,
bias=False):
"""
Convolution 1x1 layer.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
"""
return nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
groups=groups,
bias=bias)
def conv3x3(in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
groups=1,
bias=False):
"""
Convolution 3x3 layer.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int or tuple/list of 2 int, default 1
        Padding value for convolution layer.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
"""
return nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
def depthwise_conv3x3(channels,
stride):
"""
Depthwise convolution 3x3 layer.
Parameters:
----------
channels : int
Number of input/output channels.
    stride : int or tuple/list of 2 int
Strides of the convolution.
"""
return nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=stride,
padding=1,
groups=channels,
bias=False)
class ConvBlock(nn.Module):
"""
Standard convolution block with Batch normalization and ReLU/ReLU6 activation.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int or tuple/list of 2 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU(inplace=True)
Activation function or name of activation function.
activate : bool, default True
        Whether to activate the convolution block.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
groups=1,
bias=False,
bn_eps=1e-5,
activation=(lambda: nn.ReLU(inplace=True)),
activate=True):
super(ConvBlock, self).__init__()
self.activate = activate
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.bn = nn.BatchNorm2d(
num_features=out_channels,
eps=bn_eps)
if self.activate:
assert (activation is not None)
if isfunction(activation):
self.activ = activation()
elif isinstance(activation, str):
if activation == "relu":
self.activ = nn.ReLU(inplace=True)
elif activation == "relu6":
self.activ = nn.ReLU6(inplace=True)
else:
raise NotImplementedError()
else:
self.activ = activation
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.activate:
x = self.activ(x)
return x
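# Usage sketch for ConvBlock (illustrative only; the layer sizes and input shape
# below are assumptions chosen for the example, not values taken from any model):
# >>> block = ConvBlock(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
# >>> x = torch.randn(1, 3, 32, 32)
# >>> block(x).shape
# torch.Size([1, 16, 32, 32])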
def conv1x1_block(in_channels,
out_channels,
stride=1,
padding=0,
groups=1,
bias=False,
bn_eps=1e-5,
activation=(lambda: nn.ReLU(inplace=True)),
activate=True):
"""
1x1 version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int or tuple/list of 2 int, default 0
Padding value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU(inplace=True)
Activation function or name of activation function.
activate : bool, default True
        Whether to activate the convolution block.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding,
groups=groups,
bias=bias,
bn_eps=bn_eps,
activation=activation,
activate=activate)
def conv3x3_block(in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
groups=1,
bias=False,
bn_eps=1e-5,
activation=(lambda: nn.ReLU(inplace=True)),
activate=True):
"""
3x3 version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int or tuple/list of 2 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU(inplace=True)
Activation function or name of activation function.
activate : bool, default True
        Whether to activate the convolution block.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
bn_eps=bn_eps,
activation=activation,
activate=activate)
def conv7x7_block(in_channels,
out_channels,
stride=1,
padding=3,
bias=False,
activation=(lambda: nn.ReLU(inplace=True)),
activate=True):
"""
7x7 version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int or tuple/list of 2 int, default 3
Padding value for convolution layer.
bias : bool, default False
Whether the layer uses a bias vector.
activation : function or str or None, default nn.ReLU(inplace=True)
Activation function or name of activation function.
activate : bool, default True
        Whether to activate the convolution block.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=7,
stride=stride,
padding=padding,
bias=bias,
activation=activation,
activate=activate)
def dwconv3x3_block(in_channels,
out_channels,
stride,
padding=1,
dilation=1,
bias=False,
activation=(lambda: nn.ReLU(inplace=True)),
activate=True):
"""
    3x3 depthwise version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int or tuple/list of 2 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
bias : bool, default False
Whether the layer uses a bias vector.
activation : function or str or None, default nn.ReLU(inplace=True)
Activation function or name of activation function.
activate : bool, default True
        Whether to activate the convolution block.
"""
return conv3x3_block(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
padding=padding,
dilation=dilation,
groups=out_channels,
bias=bias,
activation=activation,
activate=activate)
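# Usage sketch for dwconv3x3_block (illustrative; channel counts and shapes are
# assumptions). Because groups equals out_channels, in_channels must match
# out_channels for a true depthwise convolution:
# >>> block = dwconv3x3_block(in_channels=32, out_channels=32, stride=2)
# >>> block(torch.randn(1, 32, 16, 16)).shape
# torch.Size([1, 32, 8, 8])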
class PreConvBlock(nn.Module):
"""
Convolution block with Batch normalization and ReLU pre-activation.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int or tuple/list of 2 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
bias : bool, default False
Whether the layer uses a bias vector.
return_preact : bool, default False
        Whether to return the pre-activation. It's used by PreResNet.
    activate : bool, default True
        Whether to activate the convolution block.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
bias=False,
return_preact=False,
activate=True):
super(PreConvBlock, self).__init__()
self.return_preact = return_preact
self.activate = activate
self.bn = nn.BatchNorm2d(num_features=in_channels)
if self.activate:
self.activ = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
def forward(self, x):
x = self.bn(x)
if self.activate:
x = self.activ(x)
if self.return_preact:
x_pre_activ = x
x = self.conv(x)
if self.return_preact:
return x, x_pre_activ
else:
return x
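# Usage sketch for PreConvBlock with return_preact=True (illustrative; shapes are
# assumptions). The pre-activation output keeps the input channel count, while
# the main output has out_channels channels:
# >>> block = PreConvBlock(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, return_preact=True)
# >>> y, y_pre = block(torch.randn(1, 16, 8, 8))
# >>> y.shape, y_pre.shape
# (torch.Size([1, 32, 8, 8]), torch.Size([1, 16, 8, 8]))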
def pre_conv1x1_block(in_channels,
out_channels,
stride=1,
bias=False,
return_preact=False,
activate=True):
"""
1x1 version of the pre-activated convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
bias : bool, default False
Whether the layer uses a bias vector.
return_preact : bool, default False
        Whether to return the pre-activation.
    activate : bool, default True
        Whether to activate the convolution block.
"""
return PreConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
bias=bias,
return_preact=return_preact,
activate=activate)
def pre_conv3x3_block(in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
return_preact=False,
activate=True):
"""
3x3 version of the pre-activated convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int or tuple/list of 2 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
return_preact : bool, default False
        Whether to return the pre-activation.
    activate : bool, default True
        Whether to activate the convolution block.
"""
return PreConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
return_preact=return_preact,
activate=activate)
def channel_shuffle(x,
groups):
"""
Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices,'
https://arxiv.org/abs/1707.01083.
Parameters:
----------
x : Tensor
Input tensor.
groups : int
Number of groups.
Returns
-------
Tensor
        Resulting tensor.
"""
batch, channels, height, width = x.size()
# assert (channels % groups == 0)
channels_per_group = channels // groups
x = x.view(batch, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batch, channels, height, width)
return x
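# Worked example of channel_shuffle (illustrative; the tiny shape is chosen only to
# make the channel permutation visible). With 6 channels and 2 groups, channel
# order [0, 1, 2, 3, 4, 5] becomes [0, 3, 1, 4, 2, 5]:
# >>> x = torch.arange(6).float().reshape(1, 6, 1, 1)
# >>> channel_shuffle(x, groups=2).flatten().tolist()
# [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]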
class ChannelShuffle(nn.Module):
"""
    Channel shuffle layer. This is a wrapper over the channel shuffle operation; it stores the number of groups.
Parameters:
----------
channels : int
Number of channels.
groups : int
Number of groups.
"""
def __init__(self,
channels,
groups):
super(ChannelShuffle, self).__init__()
# assert (channels % groups == 0)
if channels % groups != 0:
raise ValueError('channels must be divisible by groups')
self.groups = groups
def forward(self, x):
return channel_shuffle(x, self.groups)
def channel_shuffle2(x,
groups):
"""
Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices,'
https://arxiv.org/abs/1707.01083. The alternative version.
Parameters:
----------
x : Tensor
Input tensor.
groups : int
Number of groups.
Returns
-------
Tensor
        Resulting tensor.
"""
batch, channels, height, width = x.size()
# assert (channels % groups == 0)
channels_per_group = channels // groups
x = x.view(batch, channels_per_group, groups, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batch, channels, height, width)
return x
class ChannelShuffle2(nn.Module):
"""
    Channel shuffle layer. This is a wrapper over the channel shuffle operation; it stores the number of groups.
    This is the alternative version.
Parameters:
----------
channels : int
Number of channels.
groups : int
Number of groups.
"""
def __init__(self,
channels,
groups):
super(ChannelShuffle2, self).__init__()
# assert (channels % groups == 0)
if channels % groups != 0:
raise ValueError('channels must be divisible by groups')
self.groups = groups
def forward(self, x):
return channel_shuffle2(x, self.groups)
class SEBlock(nn.Module):
"""
Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,' https://arxiv.org/abs/1709.01507.
Parameters:
----------
channels : int
Number of channels.
reduction : int, default 16
Squeeze reduction value.
"""
def __init__(self,
channels,
reduction=16):
super(SEBlock, self).__init__()
        mid_channels = channels // reduction
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.conv1 = conv1x1(
in_channels=channels,
            out_channels=mid_channels,
bias=True)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv1x1(
            in_channels=mid_channels,
out_channels=channels,
bias=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
w = self.pool(x)
w = self.conv1(w)
w = self.relu(w)
w = self.conv2(w)
w = self.sigmoid(w)
x = x * w
return x
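# Usage sketch for SEBlock (illustrative; the channel count and input shape are
# assumptions). The block rescales channels and preserves the input shape:
# >>> se = SEBlock(channels=64)
# >>> se(torch.randn(1, 64, 16, 16)).shape
# torch.Size([1, 64, 16, 16])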
class IBN(nn.Module):
"""
Instance-Batch Normalization block from 'Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net,'
https://arxiv.org/abs/1807.09441.
Parameters:
----------
channels : int
Number of channels.
    first_fraction : float, default 0.5
        The fraction of channels assigned to the first normalization branch.
    inst_first : bool, default True
        Whether instance normalization is applied to the first part of the channels.
"""
def __init__(self,
channels,
first_fraction=0.5,
inst_first=True):
super(IBN, self).__init__()
self.inst_first = inst_first
h1_channels = int(math.floor(channels * first_fraction))
h2_channels = channels - h1_channels
self.split_sections = [h1_channels, h2_channels]
if self.inst_first:
self.inst_norm = nn.InstanceNorm2d(
num_features=h1_channels,
affine=True)
self.batch_norm = nn.BatchNorm2d(num_features=h2_channels)
else:
self.batch_norm = nn.BatchNorm2d(num_features=h1_channels)
self.inst_norm = nn.InstanceNorm2d(
num_features=h2_channels,
affine=True)
def forward(self, x):
x1, x2 = torch.split(x, split_size_or_sections=self.split_sections, dim=1)
if self.inst_first:
x1 = self.inst_norm(x1.contiguous())
x2 = self.batch_norm(x2.contiguous())
else:
x1 = self.batch_norm(x1.contiguous())
x2 = self.inst_norm(x2.contiguous())
x = torch.cat((x1, x2), dim=1)
return x
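# Usage sketch for IBN (illustrative; the channel count and shape are assumptions).
# With the defaults, the first half of the channels is instance-normalized and the
# second half is batch-normalized; the output shape matches the input:
# >>> ibn = IBN(channels=64)
# >>> ibn(torch.randn(2, 64, 8, 8)).shape
# torch.Size([2, 64, 8, 8])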
class Identity(nn.Module):
"""
Identity block.
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class DualPathSequential(nn.Sequential):
"""
A sequential container for modules with dual inputs/outputs.
Modules will be executed in the order they are added.
Parameters:
----------
return_two : bool, default True
        Whether to return two outputs after execution.
first_ordinals : int, default 0
Number of the first modules with single input/output.
last_ordinals : int, default 0
Number of the final modules with single input/output.
dual_path_scheme : function
Scheme of dual path response for a module.
dual_path_scheme_ordinal : function
Scheme of dual path response for an ordinal module.
"""
def __init__(self,
return_two=True,
first_ordinals=0,
last_ordinals=0,
dual_path_scheme=(lambda module, x1, x2: module(x1, x2)),
dual_path_scheme_ordinal=(lambda module, x1, x2: (module(x1), x2))):
super(DualPathSequential, self).__init__()
self.return_two = return_two
self.first_ordinals = first_ordinals
self.last_ordinals = last_ordinals
self.dual_path_scheme = dual_path_scheme
self.dual_path_scheme_ordinal = dual_path_scheme_ordinal
def forward(self, x1, x2=None):
length = len(self._modules.values())
for i, module in enumerate(self._modules.values()):
if (i < self.first_ordinals) or (i >= length - self.last_ordinals):
x1, x2 = self.dual_path_scheme_ordinal(module, x1, x2)
else:
x1, x2 = self.dual_path_scheme(module, x1, x2)
if self.return_two:
return x1, x2
else:
return x1
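# Usage sketch for DualPathSequential (illustrative; AddPaths is a hypothetical
# dual-input module defined only for this example). The first and last modules are
# treated as ordinary single-input blocks, while the middle one receives both paths:
# >>> class AddPaths(nn.Module):
# ...     def forward(self, x1, x2):
# ...         return x1 + x2, x2
# >>> seq = DualPathSequential(first_ordinals=1, last_ordinals=1)
# >>> seq.add_module("head", Identity())
# >>> seq.add_module("mix", AddPaths())
# >>> seq.add_module("tail", Identity())
# >>> x1, x2 = seq(torch.ones(2, 4), torch.ones(2, 4))
# >>> x1.sum().item(), x2.sum().item()
# (16.0, 8.0)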
class Concurrent(nn.Sequential):
"""
    A container for the concatenation of module outputs, based on the sequential container.
Parameters:
----------
axis : int, default 1
The axis on which to concatenate the outputs.
stack : bool, default False
Whether to concatenate tensors along a new dimension.
"""
def __init__(self,
axis=1,
stack=False):
super(Concurrent, self).__init__()
self.axis = axis
self.stack = stack
def forward(self, x):
out = []
for module in self._modules.values():
out.append(module(x))
if self.stack:
out = torch.stack(tuple(out), dim=self.axis)
else:
out = torch.cat(tuple(out), dim=self.axis)
return out
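# Usage sketch for Concurrent (illustrative; branch widths and the input shape are
# assumptions). Two parallel 1x1 convolutions are concatenated along the channel axis:
# >>> branches = Concurrent()
# >>> branches.add_module("branch1", conv1x1(in_channels=16, out_channels=8))
# >>> branches.add_module("branch2", conv1x1(in_channels=16, out_channels=4))
# >>> branches(torch.randn(1, 16, 4, 4)).shape
# torch.Size([1, 12, 4, 4])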
class ParametricSequential(nn.Sequential):
"""
A sequential container for modules with parameters.
Modules will be executed in the order they are added.
"""
def __init__(self, *args):
super(ParametricSequential, self).__init__(*args)
def forward(self, x, **kwargs):
for module in self._modules.values():
x = module(x, **kwargs)
return x
class ParametricConcurrent(nn.Sequential):
"""
    A container for the concatenation of outputs from modules that take extra parameters.
Parameters:
----------
axis : int, default 1
The axis on which to concatenate the outputs.
"""
def __init__(self, axis=1):
super(ParametricConcurrent, self).__init__()
self.axis = axis
def forward(self, x, **kwargs):
out = []
for module in self._modules.values():
out.append(module(x, **kwargs))
out = torch.cat(tuple(out), dim=self.axis)
return out
class Hourglass(nn.Module):
"""
    An hourglass block.
Parameters:
----------
down_seq : nn.Sequential
Down modules as sequential.
up_seq : nn.Sequential
Up modules as sequential.
skip_seq : nn.Sequential
Skip connection modules as sequential.
merge_type : str, default 'add'
        Type of merging of the up and skip outputs.
    return_first_skip : bool, default False
        Whether to return the first skip connection output. It's used in ResAttNet.
"""
def __init__(self,
down_seq,
up_seq,
skip_seq,
merge_type="add",
return_first_skip=False):
super(Hourglass, self).__init__()
assert (len(up_seq) == len(down_seq))
assert (len(skip_seq) == len(down_seq))
assert (merge_type in ["add"])
self.merge_type = merge_type
self.return_first_skip = return_first_skip
self.depth = len(down_seq)
self.down_seq = down_seq
self.up_seq = up_seq
self.skip_seq = skip_seq
def forward(self, x, **kwargs):
y = None
down_outs = [x]
for down_module in self.down_seq._modules.values():
x = down_module(x)
down_outs.append(x)
for i in range(len(down_outs)):
if i != 0:
y = down_outs[self.depth - i]
skip_module = self.skip_seq[self.depth - i]
y = skip_module(y)
if (y is not None) and (self.merge_type == "add"):
x = x + y
if i != len(down_outs) - 1:
up_module = self.up_seq[self.depth - 1 - i]
x = up_module(x)
if self.return_first_skip:
return x, y
else:
return x
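# Usage sketch for a depth-1 Hourglass (illustrative; the pooling/upsampling modules
# and the input shape are assumptions). The result is upsample(downsample(x)) + x:
# >>> hg = Hourglass(
# ...     down_seq=nn.Sequential(nn.MaxPool2d(kernel_size=2)),
# ...     up_seq=nn.Sequential(nn.Upsample(scale_factor=2)),
# ...     skip_seq=nn.Sequential(Identity()))
# >>> hg(torch.randn(1, 8, 16, 16)).shape
# torch.Size([1, 8, 16, 16])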
class SesquialteralHourglass(nn.Module):
"""
A sesquialteral hourglass block.
Parameters:
----------
down1_seq : nn.Sequential
The first down modules as sequential.
skip1_seq : nn.Sequential
The first skip connection modules as sequential.
up_seq : nn.Sequential
Up modules as sequential.
skip2_seq : nn.Sequential
The second skip connection modules as sequential.
down2_seq : nn.Sequential
The second down modules as sequential.
    merge_type : str, default 'cat'
        Type of merging of the up and skip outputs.
"""
def __init__(self,
down1_seq,
skip1_seq,
up_seq,
skip2_seq,
down2_seq,
merge_type="cat"):
super(SesquialteralHourglass, self).__init__()
assert (len(down1_seq) == len(up_seq))
assert (len(down1_seq) == len(down2_seq))
assert (len(skip1_seq) == len(skip2_seq))
assert (len(down1_seq) == len(skip1_seq) - 1)
assert (merge_type in ["cat", "add"])
self.merge_type = merge_type
self.depth = len(down1_seq)
self.down1_seq = down1_seq
self.skip1_seq = skip1_seq
self.up_seq = up_seq
self.skip2_seq = skip2_seq
self.down2_seq = down2_seq
def _merge(self, x, y):
if y is not None:
if self.merge_type == "cat":
x = torch.cat((x, y), dim=1)
elif self.merge_type == "add":
x = x + y
return x
def forward(self, x, **kwargs):
y = self.skip1_seq[0](x)
skip1_outs = [y]
for i in range(self.depth):
x = self.down1_seq[i](x)
y = self.skip1_seq[i + 1](x)
skip1_outs.append(y)
x = skip1_outs[self.depth]
y = self.skip2_seq[0](x)
skip2_outs = [y]
for i in range(self.depth):
x = self.up_seq[i](x)
y = skip1_outs[self.depth - 1 - i]
x = self._merge(x, y)
y = self.skip2_seq[i + 1](x)
skip2_outs.append(y)
x = self.skip2_seq[self.depth](x)
for i in range(self.depth):
x = self.down2_seq[i](x)
y = skip2_outs[self.depth - 1 - i]
x = self._merge(x, y)
return x
class MultiOutputSequential(nn.Sequential):
"""
A sequential container with multiple outputs.
Modules will be executed in the order they are added.
"""
def __init__(self):
super(MultiOutputSequential, self).__init__()
def forward(self, x):
outs = []
for module in self._modules.values():
x = module(x)
if hasattr(module, "do_output") and module.do_output:
outs.append(x)
return [x] + outs
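# Usage sketch for MultiOutputSequential (illustrative; stage widths and the input
# shape are assumptions). A module whose do_output attribute is True also contributes
# its intermediate output; the final output comes first in the returned list:
# >>> stage1 = conv3x3_block(in_channels=3, out_channels=16)
# >>> stage1.do_output = True
# >>> stage2 = conv3x3_block(in_channels=16, out_channels=32, stride=2)
# >>> net = MultiOutputSequential()
# >>> net.add_module("stage1", stage1)
# >>> net.add_module("stage2", stage2)
# >>> [y.shape for y in net(torch.randn(1, 3, 32, 32))]
# [torch.Size([1, 32, 16, 16]), torch.Size([1, 16, 32, 32])]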
class Flatten(nn.Module):
"""
Simple flatten module.
"""
def forward(self, x):
return x.view(x.size(0), -1)
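# Usage sketch for Flatten (illustrative): all non-batch dimensions are collapsed:
# >>> Flatten()(torch.randn(2, 8, 4, 4)).shape
# torch.Size([2, 128])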