--- a +++ b/model.py @@ -0,0 +1,100 @@ +""" +UNet +The main UNet model implementation +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Utility Functions +''' when filter kernel= 3x3, padding=1 makes in&out matrix same size''' +def conv_bn_leru(in_channels, out_channels, kernel_size=3, stride=1, padding=1): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + +def down_pooling(): + return nn.MaxPool2d(2) + +def up_pooling(in_channels, out_channels, kernel_size=2, stride=2): + return nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + +# UNet class + +class UNet(nn.Module): + def __init__(self, input_channels, nclasses): + super().__init__() + # go down + self.conv1 = conv_bn_leru(input_channels,64) + self.conv2 = conv_bn_leru(64, 128) + self.conv3 = conv_bn_leru(128, 256) + self.conv4 = conv_bn_leru(256, 512) + self.conv5 = conv_bn_leru(512, 1024) + self.down_pooling = nn.MaxPool2d(2) + + # go up + self.up_pool6 = up_pooling(1024, 512) + self.conv6 = conv_bn_leru(1024, 512) + self.up_pool7 = up_pooling(512, 256) + self.conv7 = conv_bn_leru(512, 256) + self.up_pool8 = up_pooling(256, 128) + self.conv8 = conv_bn_leru(256, 128) + self.up_pool9 = up_pooling(128, 64) + self.conv9 = conv_bn_leru(128, 64) + + self.conv10 = nn.Conv2d(64, nclasses, 1) + + + # test weight init + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_out') + if m.bias is not None: + m.bias.data.zero_() + + + def forward(self, x): + # go down + x1 = self.conv1(x) + p1 = self.down_pooling(x1) + x2 = self.conv2(p1) + p2 = self.down_pooling(x2) + x3 = self.conv3(p2) + p3 = self.down_pooling(x3) + x4 = self.conv4(p3) + p4 = self.down_pooling(x4) + x5 = self.conv5(p4) + + # go up + p6 = self.up_pool6(x5) + x6 = torch.cat([p6, x4], dim=1) + x6 = self.conv6(x6) + + p7 = self.up_pool7(x6) + x7 = torch.cat([p7, x3], dim=1) + x7 = self.conv7(x7) + + p8 = self.up_pool8(x7) + x8 = torch.cat([p8, x2], dim=1) + x8 = self.conv8(x8) + + p9 = self.up_pool9(x8) + x9 = torch.cat([p9, x1], dim=1) + x9 = self.conv9(x9) + + + output = self.conv10(x9) + output = F.sigmoid(output) + + return output