Diff of /pathflowai/unet.py [000000] .. [e9500f]

Switch to unified view

a b/pathflowai/unet.py
1
# From https://raw.githubusercontent.com/milesial/Pytorch-UNet/master/unet/unet_model.py
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
7
8
class double_conv(nn.Module):
9
    '''(conv => BN => ReLU) * 2'''
10
    def __init__(self, in_ch, out_ch):
11
        super(double_conv, self).__init__()
12
        self.conv = nn.Sequential(
13
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
14
            nn.BatchNorm2d(out_ch),
15
            nn.ReLU(inplace=True),
16
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
17
            nn.BatchNorm2d(out_ch),
18
            nn.ReLU(inplace=True)
19
        )
20
21
    def forward(self, x):
22
        x = self.conv(x)
23
        return x
24
25
26
class inconv(nn.Module):
27
    def __init__(self, in_ch, out_ch):
28
        super(inconv, self).__init__()
29
        self.conv = double_conv(in_ch, out_ch)
30
31
    def forward(self, x):
32
        x = self.conv(x)
33
        return x
34
35
36
class down(nn.Module):
37
    def __init__(self, in_ch, out_ch):
38
        super(down, self).__init__()
39
        self.mpconv = nn.Sequential(
40
            nn.MaxPool2d(2),
41
            double_conv(in_ch, out_ch)
42
        )
43
44
    def forward(self, x):
45
        x = self.mpconv(x)
46
        return x
47
48
49
class up(nn.Module):
50
    def __init__(self, in_ch, out_ch, bilinear=True):
51
        super(up, self).__init__()
52
53
        #  would be a nice idea if the upsampling could be learned too,
54
        #  but my machine do not have enough memory to handle all those weights
55
        if bilinear:
56
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
57
        else:
58
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
59
60
        self.conv = double_conv(in_ch, out_ch)
61
62
    def forward(self, x1, x2):
63
        x1 = self.up(x1)
64
65
        # input is CHW
66
        diffY = x2.size()[2] - x1.size()[2]
67
        diffX = x2.size()[3] - x1.size()[3]
68
69
        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
70
                        diffY // 2, diffY - diffY//2))
71
72
        # for padding issues, see
73
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
74
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
75
76
        x = torch.cat([x2, x1], dim=1)
77
        x = self.conv(x)
78
        return x
79
80
81
class outconv(nn.Module):
82
    def __init__(self, in_ch, out_ch):
83
        super(outconv, self).__init__()
84
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
85
86
    def forward(self, x):
87
        x = self.conv(x)
88
        return x
89
90
class UNet(nn.Module):
91
    def __init__(self, n_channels, n_classes, use_sigmoid=False, use_softmax=False):
92
        super(UNet, self).__init__()
93
        self.inc = inconv(n_channels, 64)
94
        self.down1 = down(64, 128)
95
        self.down2 = down(128, 256)
96
        self.down3 = down(256, 512)
97
        self.down4 = down(512, 512)
98
        self.up1 = up(1024, 256)
99
        self.up2 = up(512, 128)
100
        self.up3 = up(256, 64)
101
        self.up4 = up(128, 64)
102
        self.outc = outconv(64, n_classes)
103
        self.sigmoid = nn.Sequential(nn.Sigmoid() if use_sigmoid else nn.Dropout(p=0.),nn.LogSoftmax(dim=1) if use_softmax else nn.Dropout(p=0.))
104
105
    def forward(self, x):
106
        x1 = self.inc(x)
107
        x2 = self.down1(x1)
108
        x3 = self.down2(x2)
109
        x4 = self.down3(x3)
110
        x5 = self.down4(x4)
111
        x = self.up1(x5, x4)
112
        x = self.up2(x, x3)
113
        x = self.up3(x, x2)
114
        x = self.up4(x, x1)
115
        x = self.outc(x)
116
        return self.sigmoid(x)