--- a
+++ b/semseg/models.py
@@ -0,0 +1,370 @@
+from torch import nn
+import torch
+from torchvision import models
+import torchvision
+from torch.nn import functional as F
+
+
+def conv3x3(in_, out):
+    return nn.Conv2d(in_, out, 3, padding=1)
+
+
+class ConvRelu(nn.Module):
+    def __init__(self, in_: int, out: int):
+        super().__init__()
+        self.conv = conv3x3(in_, out)
+        self.activation = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.activation(x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
+        super().__init__()
+        self.in_channels = in_channels
+
+        if is_deconv:
+            self.block = nn.Sequential(
+                ConvRelu(in_channels, middle_channels),
+                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
+                                   padding=1),
+                nn.ReLU(inplace=True)
+            )
+        else:
+            self.block = nn.Sequential(
+                # Since PyTorch 0.4.0 the default for mode='bilinear' is
+                # align_corners=False; align_corners=True is passed explicitly
+                # to keep the pre-0.4.0 upsampling behaviour.
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+                ConvRelu(in_channels, middle_channels),
+                ConvRelu(middle_channels, out_channels),
+            )
+
+    def forward(self, x):
+        return self.block(x)
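+
+
+def _decoder_block_shape_sketch():
+    """Illustrative sketch (not used by the models below): both DecoderBlock
+    variants double the spatial resolution. The channel sizes and the 64x64
+    input are arbitrary example values."""
+    for is_deconv in (True, False):
+        block = DecoderBlock(8, 16, 8, is_deconv=is_deconv)
+        with torch.no_grad():
+            y = block(torch.zeros(1, 8, 64, 64))
+        # ConvTranspose2d(kernel_size=4, stride=2, padding=1) and
+        # Upsample(scale_factor=2) both map 64x64 feature maps to 128x128.
+        assert y.shape == (1, 8, 128, 128)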
+
+
+class UNet11(nn.Module):
+    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
+        """
+        :param num_classes:
+        :param num_filters:
+        :param pretrained:
+            False - no pre-trained network used
+            vgg - encoder pre-trained with VGG11
+        """
+        super().__init__()
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.num_classes = num_classes
+
+        self.encoder = models.vgg11(pretrained=pretrained).features
+
+        self.relu = nn.ReLU(inplace=True)
+        self.conv1 = nn.Sequential(self.encoder[0],
+                                   self.relu)
+
+        self.conv2 = nn.Sequential(self.encoder[3],
+                                   self.relu)
+
+        self.conv3 = nn.Sequential(
+            self.encoder[6],
+            self.relu,
+            self.encoder[8],
+            self.relu,
+        )
+        self.conv4 = nn.Sequential(
+            self.encoder[11],
+            self.relu,
+            self.encoder[13],
+            self.relu,
+        )
+
+        self.conv5 = nn.Sequential(
+            self.encoder[16],
+            self.relu,
+            self.encoder[18],
+            self.relu,
+        )
+
+        self.center = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
+        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
+        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=is_deconv)
+        self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
+        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
+        self.dec1 = ConvRelu(64 + num_filters, num_filters)
+
+        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(self.pool(conv1))
+        conv3 = self.conv3(self.pool(conv2))
+        conv4 = self.conv4(self.pool(conv3))
+        conv5 = self.conv5(self.pool(conv4))
+        center = self.center(self.pool(conv5))
+
+        dec5 = self.dec5(torch.cat([center, conv5], 1))
+        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
+        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
+        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
+        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
+
+        if self.num_classes > 1:
+            x_out = F.log_softmax(self.final(dec1), dim=1)
+        else:
+            x_out = self.final(dec1)
+
+        return x_out
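+
+
+def _unet11_usage_sketch():
+    """Usage sketch (not part of the original code): run an untrained UNet11 on
+    a dummy batch. With five 2x2 poolings an input size divisible by 32 is
+    assumed; the 256x256 value is arbitrary."""
+    model = UNet11(num_classes=3, pretrained=False).eval()
+    with torch.no_grad():
+        out = model(torch.zeros(1, 3, 256, 256))
+    # With num_classes > 1 the output holds per-pixel log-probabilities at the
+    # input resolution.
+    assert out.shape == (1, 3, 256, 256)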
+
+
+class UNet16(nn.Module):
+    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
+        """
+        :param num_classes:
+        :param num_filters:
+        :param pretrained: if encoder uses pre-trained weigths from VGG16
+        """
+        super().__init__()
+        self.num_classes = num_classes
+
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features
+
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv1 = nn.Sequential(self.encoder[0],
+                                   self.relu,
+                                   self.encoder[2],
+                                   self.relu)
+
+        self.conv2 = nn.Sequential(self.encoder[5],
+                                   self.relu,
+                                   self.encoder[7],
+                                   self.relu)
+
+        self.conv3 = nn.Sequential(self.encoder[10],
+                                   self.relu,
+                                   self.encoder[12],
+                                   self.relu,
+                                   self.encoder[14],
+                                   self.relu)
+
+        self.conv4 = nn.Sequential(self.encoder[17],
+                                   self.relu,
+                                   self.encoder[19],
+                                   self.relu,
+                                   self.encoder[21],
+                                   self.relu)
+
+        self.conv5 = nn.Sequential(self.encoder[24],
+                                   self.relu,
+                                   self.encoder[26],
+                                   self.relu,
+                                   self.encoder[28],
+                                   self.relu)
+
+        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
+
+        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
+        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
+        self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
+        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
+        self.dec1 = ConvRelu(64 + num_filters, num_filters)
+        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(self.pool(conv1))
+        conv3 = self.conv3(self.pool(conv2))
+        conv4 = self.conv4(self.pool(conv3))
+        conv5 = self.conv5(self.pool(conv4))
+
+        center = self.center(self.pool(conv5))
+
+        dec5 = self.dec5(torch.cat([center, conv5], 1))
+
+        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
+        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
+        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
+        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
+
+        if self.num_classes > 1:
+            x_out = F.log_softmax(self.final(dec1), dim=1)
+        else:
+            x_out = self.final(dec1)
+
+        return x_out
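+
+
+def _unet16_usage_sketch():
+    """Usage sketch, analogous to the UNet11 one above. The VGG16 encoder also
+    downsamples five times, so an input size divisible by 32 is assumed; the
+    224x224 value is arbitrary."""
+    model = UNet16(num_classes=1, pretrained=False).eval()
+    with torch.no_grad():
+        out = model(torch.zeros(1, 3, 224, 224))
+    # With num_classes=1 the raw 1-channel logits are returned (no log_softmax).
+    assert out.shape == (1, 1, 224, 224)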
+
+
+class Conv3BN(nn.Module):
+    def __init__(self, in_: int, out: int, bn=False):
+        super().__init__()
+        self.conv = conv3x3(in_, out)
+        self.bn = nn.BatchNorm2d(out) if bn else None
+        self.activation = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn is not None:
+            x = self.bn(x)
+        x = self.activation(x)
+        return x
+
+
+class UNetModule(nn.Module):
+    def __init__(self, in_: int, out: int):
+        super().__init__()
+        self.l1 = Conv3BN(in_, out)
+        self.l2 = Conv3BN(out, out)
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = self.l2(x)
+        return x
+
+
+class UNet(nn.Module):
+    """
+    Vanilla UNet.
+
+    Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py
+    """
+    output_downscaled = 1
+    module = UNetModule
+
+    def __init__(self,
+                 input_channels: int = 3,
+                 filters_base: int = 32,
+                 down_filter_factors=(1, 2, 4, 8, 16),
+                 up_filter_factors=(1, 2, 4, 8, 16),
+                 bottom_s=4,
+                 num_classes=1,
+                 add_output=True):
+        super().__init__()
+        self.num_classes = num_classes
+        assert len(down_filter_factors) == len(up_filter_factors)
+        assert down_filter_factors[-1] == up_filter_factors[-1]
+        down_filter_sizes = [filters_base * s for s in down_filter_factors]
+        up_filter_sizes = [filters_base * s for s in up_filter_factors]
+        self.down, self.up = nn.ModuleList(), nn.ModuleList()
+        self.down.append(self.module(input_channels, down_filter_sizes[0]))
+        for prev_i, nf in enumerate(down_filter_sizes[1:]):
+            self.down.append(self.module(down_filter_sizes[prev_i], nf))
+        for prev_i, nf in enumerate(up_filter_sizes[1:]):
+            self.up.append(self.module(
+                down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]))
+        pool = nn.MaxPool2d(2, 2)
+        pool_bottom = nn.MaxPool2d(bottom_s, bottom_s)
+        # Since PyTorch 0.4.0 the default for mode='bilinear' is align_corners=False;
+        # align_corners=True keeps the pre-0.4.0 upsampling behaviour. mode='bilinear'
+        # is set explicitly because the default 'nearest' mode raises an error at
+        # train time ("align_corners option can only be set with the interpolating
+        # modes: linear | bilinear | bicubic | trilinear").
+        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+        upsample_bottom = nn.Upsample(scale_factor=bottom_s, mode='bilinear', align_corners=True)
+        self.downsamplers = [None] + [pool] * (len(self.down) - 1)
+        self.downsamplers[-1] = pool_bottom
+        self.upsamplers = [upsample] * len(self.up)
+        self.upsamplers[-1] = upsample_bottom
+        self.add_output = add_output
+        if add_output:
+            self.conv_final = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
+
+    def forward(self, x):
+        xs = []
+        for downsample, down in zip(self.downsamplers, self.down):
+            x_in = x if downsample is None else downsample(xs[-1])
+            x_out = down(x_in)
+            xs.append(x_out)
+
+        x_out = xs[-1]
+        for x_skip, upsample, up in reversed(
+                list(zip(xs[:-1], self.upsamplers, self.up))):
+            x_out = upsample(x_out)
+            x_out = up(torch.cat([x_out, x_skip], 1))
+
+        if self.add_output:
+            x_out = self.conv_final(x_out)
+            if self.num_classes > 1:
+                x_out = F.log_softmax(x_out, dim=1)
+        return x_out
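+
+
+def _vanilla_unet_usage_sketch():
+    """Usage sketch for the vanilla UNet with its default filter factors: the
+    downsampling steps (2, 2, 2, bottom_s=4) give a total factor of 32, so an
+    input size divisible by 32 is assumed; 128x128 is an arbitrary choice."""
+    model = UNet(input_channels=3, num_classes=2).eval()
+    with torch.no_grad():
+        out = model(torch.zeros(1, 3, 128, 128))
+    # With num_classes > 1 the output holds per-pixel log-probabilities.
+    assert out.shape == (1, 2, 128, 128)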
+
+
+class AlbuNet34(nn.Module):
+    """
+        UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder
+        Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
+        """
+
+    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
+        """
+        :param num_classes:
+        :param num_filters:
+        :param pretrained:
+            False - no pre-trained network is used
+            True  - encoder is pre-trained with resnet34
+        :is_deconv:
+            False: bilinear interpolation is used in decoder
+            True: deconvolution is used in decoder
+        """
+        super().__init__()
+        self.num_classes = num_classes
+
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.encoder = torchvision.models.resnet34(pretrained=pretrained)
+
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv1 = nn.Sequential(self.encoder.conv1,
+                                   self.encoder.bn1,
+                                   self.encoder.relu,
+                                   self.pool)
+
+        self.conv2 = self.encoder.layer1
+
+        self.conv3 = self.encoder.layer2
+
+        self.conv4 = self.encoder.layer3
+
+        self.conv5 = self.encoder.layer4
+
+        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
+
+        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
+        self.dec4 = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
+        self.dec3 = DecoderBlock(128 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
+        self.dec2 = DecoderBlock(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
+        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
+        self.dec0 = ConvRelu(num_filters, num_filters)
+        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(conv1)
+        conv3 = self.conv3(conv2)
+        conv4 = self.conv4(conv3)
+        conv5 = self.conv5(conv4)
+
+        center = self.center(self.pool(conv5))
+
+        dec5 = self.dec5(torch.cat([center, conv5], 1))
+
+        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
+        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
+        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
+        dec1 = self.dec1(dec2)
+        dec0 = self.dec0(dec1)
+
+        if self.num_classes > 1:
+            x_out = F.log_softmax(self.final(dec0), dim=1)
+        else:
+            x_out = self.final(dec0)
+
+        return x_out
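+
+
+def _albunet34_usage_sketch():
+    """Usage sketch: an untrained AlbuNet34 on a dummy batch. The ResNet-34 stem
+    plus the extra pooling before the center block downsample by 64 in total, so
+    an input size divisible by 64 is assumed; 256x256 is an arbitrary choice."""
+    model = AlbuNet34(num_classes=1, pretrained=False).eval()
+    with torch.no_grad():
+        out = model(torch.zeros(1, 3, 256, 256))
+    # The decoder upsamples back to the input resolution.
+    assert out.shape == (1, 1, 256, 256)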