from torch import nn
import torch
from torchvision import models
import torchvision
from torch.nn import functional as F


def conv3x3(in_, out):
    """Return a 3x3 Conv2d with padding=1 (preserves spatial size)."""
    return nn.Conv2d(in_, out, 3, padding=1)


class ConvRelu(nn.Module):
    """3x3 convolution followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x


class DecoderBlock(nn.Module):
    """U-Net decoder block that doubles the spatial resolution.

    :param in_channels: channels of the (possibly concatenated) input map
    :param middle_channels: channels of the intermediate representation
    :param out_channels: channels of the upsampled output
    :param is_deconv:
        True  - learned upsampling via ConvTranspose2d
        False - fixed bilinear upsampling followed by two ConvRelu layers
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4,
                                   stride=2, padding=1),
                nn.ReLU(inplace=True),
            )
        else:
            self.block = nn.Sequential(
                # Since PyTorch 0.4.0 the default for mode='bilinear' is
                # align_corners=False; align_corners=True keeps the old behavior.
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)


class UNet11(nn.Module):
    """U-Net with a VGG11 encoder (TernausNet-style).

    :param num_classes: number of output channels; log-softmax applied if > 1
    :param num_filters: base channel width of the decoder
    :param pretrained:
        False - no pre-trained network used
        True  - encoder pre-trained on ImageNet (VGG11)
    :param is_deconv: decoder upsampling mode, see DecoderBlock
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        self.num_classes = num_classes

        self.encoder = models.vgg11(pretrained=pretrained).features

        self.relu = nn.ReLU(inplace=True)
        # Slice the VGG11 `features` sequence into five conv stages.
        self.conv1 = nn.Sequential(self.encoder[0],
                                   self.relu)

        self.conv2 = nn.Sequential(self.encoder[3],
                                   self.relu)

        self.conv3 = nn.Sequential(
            self.encoder[6],
            self.relu,
            self.encoder[8],
            self.relu,
        )
        self.conv4 = nn.Sequential(
            self.encoder[11],
            self.relu,
            self.encoder[13],
            self.relu,
        )

        self.conv5 = nn.Sequential(
            self.encoder[16],
            self.relu,
            self.encoder[18],
            self.relu,
        )

        # The center always receives pool(conv5), which has 512 channels
        # regardless of num_filters.  (The previous `256 + num_filters * 8`
        # only equaled 512 for the default num_filters=32 and broke for any
        # other value; 512 is also what UNet16/AlbuNet34 use here.)
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=is_deconv)
        self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))
        center = self.center(self.pool(conv5))

        # Decoder: upsample and concatenate the matching encoder stage.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec1), dim=1)
        else:
            x_out = self.final(dec1)

        return x_out


class UNet16(nn.Module):
    """U-Net with a VGG16 encoder.

    :param num_classes: number of output channels; log-softmax applied if > 1
    :param num_filters: base channel width of the decoder
    :param pretrained: if the encoder uses pre-trained weights from VGG16
    :param is_deconv: decoder upsampling mode, see DecoderBlock
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        self.relu = nn.ReLU(inplace=True)

        # Slice the VGG16 `features` sequence into five conv stages.
        self.conv1 = nn.Sequential(self.encoder[0],
                                   self.relu,
                                   self.encoder[2],
                                   self.relu)

        self.conv2 = nn.Sequential(self.encoder[5],
                                   self.relu,
                                   self.encoder[7],
                                   self.relu)

        self.conv3 = nn.Sequential(self.encoder[10],
                                   self.relu,
                                   self.encoder[12],
                                   self.relu,
                                   self.encoder[14],
                                   self.relu)

        self.conv4 = nn.Sequential(self.encoder[17],
                                   self.relu,
                                   self.encoder[19],
                                   self.relu,
                                   self.encoder[21],
                                   self.relu)

        self.conv5 = nn.Sequential(self.encoder[24],
                                   self.relu,
                                   self.encoder[26],
                                   self.relu,
                                   self.encoder[28],
                                   self.relu)

        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)

        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
        self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))

        center = self.center(self.pool(conv5))

        # Decoder: upsample and concatenate the matching encoder stage.
        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec1), dim=1)
        else:
            x_out = self.final(dec1)

        return x_out


class Conv3BN(nn.Module):
    """3x3 convolution, optional BatchNorm, then in-place ReLU."""

    def __init__(self, in_: int, out: int, bn=False):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.bn = nn.BatchNorm2d(out) if bn else None
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.activation(x)
        return x


class UNetModule(nn.Module):
    """Two stacked Conv3BN layers — one U-Net encoder/decoder stage."""

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.l1 = Conv3BN(in_, out)
        self.l2 = Conv3BN(out, out)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


class UNet(nn.Module):
    """
    Vanilla UNet.

    Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py
    """
    output_downscaled = 1
    module = UNetModule

    def __init__(self,
                 input_channels: int = 3,
                 filters_base: int = 32,
                 down_filter_factors=(1, 2, 4, 8, 16),
                 up_filter_factors=(1, 2, 4, 8, 16),
                 bottom_s=4,
                 num_classes=1,
                 add_output=True):
        super().__init__()
        self.num_classes = num_classes
        assert len(down_filter_factors) == len(up_filter_factors)
        assert down_filter_factors[-1] == up_filter_factors[-1]
        down_filter_sizes = [filters_base * s for s in down_filter_factors]
        up_filter_sizes = [filters_base * s for s in up_filter_factors]
        self.down, self.up = nn.ModuleList(), nn.ModuleList()
        self.down.append(self.module(input_channels, down_filter_sizes[0]))
        for prev_i, nf in enumerate(down_filter_sizes[1:]):
            self.down.append(self.module(down_filter_sizes[prev_i], nf))
        for prev_i, nf in enumerate(up_filter_sizes[1:]):
            self.up.append(self.module(
                down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]))
        pool = nn.MaxPool2d(2, 2)
        # The bottom stage downsamples/upsamples by bottom_s instead of 2.
        pool_bottom = nn.MaxPool2d(bottom_s, bottom_s)
        # mode='bilinear' is set explicitly because the default ('nearest')
        # raises "align_corners option can only be set with the interpolating
        # modes" at train time; align_corners=True keeps the pre-0.4.0
        # Upsample behavior.
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        upsample_bottom = nn.Upsample(scale_factor=bottom_s, mode='bilinear', align_corners=True)
        # NOTE(review): plain Python lists, so these are not registered as
        # submodules — harmless here since pools/upsamples are parameter-free.
        self.downsamplers = [None] + [pool] * (len(self.down) - 1)
        self.downsamplers[-1] = pool_bottom
        self.upsamplers = [upsample] * len(self.up)
        self.upsamplers[-1] = upsample_bottom
        self.add_output = add_output
        if add_output:
            self.conv_final = nn.Conv2d(up_filter_sizes[0], num_classes, 1)

    def forward(self, x):
        xs = []
        # Encoder: first stage takes the input directly (downsampler is None).
        for downsample, down in zip(self.downsamplers, self.down):
            x_in = x if downsample is None else downsample(xs[-1])
            x_out = down(x_in)
            xs.append(x_out)

        x_out = xs[-1]
        # Decoder: walk the skip connections from deepest to shallowest.
        for x_skip, upsample, up in reversed(
                list(zip(xs[:-1], self.upsamplers, self.up))):
            x_out = upsample(x_out)
            x_out = up(torch.cat([x_out, x_skip], 1))

        if self.add_output:
            x_out = self.conv_final(x_out)
            if self.num_classes > 1:
                x_out = F.log_softmax(x_out, dim=1)
        return x_out


class AlbuNet34(nn.Module):
    """
    UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        """
        :param num_classes:
        :param num_filters:
        :param pretrained:
            False - no pre-trained network is used
            True - encoder is pre-trained with resnet34
        :is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.resnet34(pretrained=pretrained)

        self.relu = nn.ReLU(inplace=True)

        # conv1 includes a pool, so the ResNet layer stages below are applied
        # back-to-back (layer2-4 stride internally).
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1

        self.conv3 = self.encoder.layer2

        self.conv4 = self.encoder.layer3

        self.conv5 = self.encoder.layer4

        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv)

        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(128 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        # dec1/dec0 have no skip connection: conv1's stride-4 output does not
        # match their resolution, so the last stages upsample unconcatenated.
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec0), dim=1)
        else:
            x_out = self.final(dec0)

        return x_out