###########################################################################
# Created by: Tramac
# Date: 2019-03-25
# Copyright (c) 2017
###########################################################################
"""Fast Segmentation Convolutional Neural Network"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['FastSCNN', 'get_fast_scnn']
class FastSCNN(nn.Module):
def __init__(self, num_classes, aux=False, **kwargs):
super(FastSCNN, self).__init__()
self.aux = aux
self.learning_to_downsample = LearningToDownsample(32, 48, 64)
self.global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3])
self.feature_fusion = FeatureFusionModule(64, 128, 128)
self.classifier = Classifer(128, num_classes)
if self.aux:
self.auxlayer = nn.Sequential(
nn.Conv2d(64, 32, 3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.Dropout(0.1),
nn.Conv2d(32, num_classes, 1)
)
def forward(self, x):
size = x.size()[2:]
higher_res_features = self.learning_to_downsample(x)
x = self.global_feature_extractor(higher_res_features)
x = self.feature_fusion(higher_res_features, x)
x = self.classifier(x)
#outputs = []
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
#outputs.append(x)
if self.aux:
auxout = self.auxlayer(higher_res_features)
auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
outputs.append(auxout)
return x#tuple(outputs)
class _ConvBNReLU(nn.Module):
"""Conv-BN-ReLU"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, **kwargs):
super(_ConvBNReLU, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True)
)
def forward(self, x):
return self.conv(x)
class _DSConv(nn.Module):
"""Depthwise Separable Convolutions"""
def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
super(_DSConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False),
nn.BatchNorm2d(dw_channels),
nn.ReLU(True),
nn.Conv2d(dw_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True)
)
def forward(self, x):
return self.conv(x)
class _DWConv(nn.Module):
def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
super(_DWConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True)
)
def forward(self, x):
return self.conv(x)
class LinearBottleneck(nn.Module):
"""LinearBottleneck used in MobileNetV2"""
def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_channels == out_channels
self.block = nn.Sequential(
# pw
_ConvBNReLU(in_channels, in_channels * t, 1),
# dw
_DWConv(in_channels * t, in_channels * t, stride),
# pw-linear
nn.Conv2d(in_channels * t, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.block(x)
if self.use_shortcut:
out = x + out
return out
class PyramidPooling(nn.Module):
"""Pyramid pooling module"""
def __init__(self, in_channels, out_channels, **kwargs):
super(PyramidPooling, self).__init__()
inter_channels = int(in_channels / 4)
self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)
def pool(self, x, size):
avgpool = nn.AdaptiveAvgPool2d(size)
return avgpool(x)
def upsample(self, x, size):
return F.interpolate(x, size, mode='bilinear', align_corners=True)
def forward(self, x):
size = x.size()[2:]
feat1 = self.upsample(self.conv1(self.pool(x, 1)), size)
feat2 = self.upsample(self.conv2(self.pool(x, 2)), size)
feat3 = self.upsample(self.conv3(self.pool(x, 3)), size)
feat4 = self.upsample(self.conv4(self.pool(x, 6)), size)
x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
x = self.out(x)
return x
class LearningToDownsample(nn.Module):
"""Learning to downsample module"""
def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64, **kwargs):
super(LearningToDownsample, self).__init__()
self.conv = _ConvBNReLU(3, dw_channels1, 3, 2)
self.dsconv1 = _DSConv(dw_channels1, dw_channels2, 2)
self.dsconv2 = _DSConv(dw_channels2, out_channels, 2)
def forward(self, x):
x = self.conv(x)
x = self.dsconv1(x)
x = self.dsconv2(x)
return x
class GlobalFeatureExtractor(nn.Module):
"""Global feature extractor module"""
def __init__(self, in_channels=64, block_channels=(64, 96, 128),
out_channels=128, t=6, num_blocks=(3, 3, 3), **kwargs):
super(GlobalFeatureExtractor, self).__init__()
self.bottleneck1 = self._make_layer(LinearBottleneck, in_channels, block_channels[0], num_blocks[0], t, 2)
self.bottleneck2 = self._make_layer(LinearBottleneck, block_channels[0], block_channels[1], num_blocks[1], t, 2)
self.bottleneck3 = self._make_layer(LinearBottleneck, block_channels[1], block_channels[2], num_blocks[2], t, 1)
self.ppm = PyramidPooling(block_channels[2], out_channels)
def _make_layer(self, block, inplanes, planes, blocks, t=6, stride=1):
layers = []
layers.append(block(inplanes, planes, t, stride))
for i in range(1, blocks):
layers.append(block(planes, planes, t, 1))
return nn.Sequential(*layers)
def forward(self, x):
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = self.ppm(x)
return x
class FeatureFusionModule(nn.Module):
"""Feature fusion module"""
def __init__(self, highter_in_channels, lower_in_channels, out_channels, scale_factor=4, **kwargs):
super(FeatureFusionModule, self).__init__()
self.scale_factor = scale_factor
self.dwconv = _DWConv(lower_in_channels, out_channels, 1)
self.conv_lower_res = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 1),
nn.BatchNorm2d(out_channels)
)
self.conv_higher_res = nn.Sequential(
nn.Conv2d(highter_in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels)
)
self.relu = nn.ReLU(True)
def forward(self, higher_res_feature, lower_res_feature):
lower_res_feature = F.interpolate(lower_res_feature, scale_factor=4, mode='bilinear', align_corners=True)
lower_res_feature = self.dwconv(lower_res_feature)
lower_res_feature = self.conv_lower_res(lower_res_feature)
higher_res_feature = self.conv_higher_res(higher_res_feature)
out = higher_res_feature + lower_res_feature
return self.relu(out)
class Classifer(nn.Module):
"""Classifer"""
def __init__(self, dw_channels, num_classes, stride=1, **kwargs):
super(Classifer, self).__init__()
self.dsconv1 = _DSConv(dw_channels, dw_channels, stride)
self.dsconv2 = _DSConv(dw_channels, dw_channels, stride)
self.conv = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(dw_channels, num_classes, 1)
)
def forward(self, x):
x = self.dsconv1(x)
x = self.dsconv2(x)
x = self.conv(x)
return x
def get_fast_scnn(n_classes, **kwargs):
model = FastSCNN(n_classes, **kwargs)
return model