# model.py — 3D U-Net with pre-activation residual encoder blocks.
import sys

from torch.distributions.normal import Normal
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Pre-activation residual 3D convolution block.

    Applies ReLU -> Conv3d -> GroupNorm twice (with a final ReLU) and adds
    the input back as a skip connection.  When ``out_channels`` differs from
    ``in_channels``, the identity branch is zero-padded along the channel
    axis so the residual sum is well-defined.

    Note: ``bn`` is accepted for interface compatibility but is unused —
    normalization is always GroupNorm here.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 bn=False,
                 num_groups=8):
        super(Encoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.relu = nn.ReLU()
        # Both convs share the same spatial hyperparameters.
        conv_kwargs = dict(kernel_size=kernel_size,
                           stride=stride,
                           padding=padding,
                           bias=bias)
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.gn1 = nn.GroupNorm(num_groups, out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.gn2 = nn.GroupNorm(num_groups, out_channels)

    def forward(self, x):
        """Return ``f(x) + pad(x)`` where ``f`` is the two-conv branch."""
        out = self.relu(x)
        out = self.gn1(self.conv1(out))
        out = self.relu(out)
        out = self.gn2(self.conv2(out))
        out = self.relu(out)

        shortcut = x
        extra = self.out_channels - self.in_channels
        if extra != 0:
            # F.pad orders pad amounts from the last dimension backwards, so
            # for a 5-D (N, C, D, H, W) tensor index 6 is "before channel":
            # the new zero channels are inserted ahead of the existing ones.
            pad_spec = [0] * (2 * x.dim())
            pad_spec[6] = extra
            shortcut = F.pad(shortcut, pad_spec, mode='constant', value=0)
        return out + shortcut
class UNet3D(nn.Module):
    """3D U-Net built from residual ``Encoder`` blocks.

    Three max-pool downsampling stages double the channel width each time
    (``inplanes * 2 .. inplanes * 16``); three transposed-conv upsampling
    stages mirror them, with skip connections concatenated on the channel
    axis.  The final stage is a 1x1x1 projection to ``n_classes`` logits
    with no normalization or activation.
    """

    def __init__(self,
                 in_channel,
                 n_classes,
                 use_bias=True,
                 inplanes=32,
                 num_groups=8):
        super(UNet3D, self).__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.inplanes = inplanes
        self.num_groups = num_groups
        # Channel widths per resolution level: inplanes * 2**i, i = 0..4.
        widths = [inplanes * 2 ** i for i in range(5)]

        def block(cin, cout):
            # Residual encoder block with the network-wide bias/GN settings.
            return Encoder(cin, cout, bias=use_bias, num_groups=num_groups)

        # Contracting path.
        self.ec0 = block(in_channel, widths[1])
        self.ec1 = block(widths[1], widths[2])
        self.ec1_2 = block(widths[2], widths[2])
        self.ec2 = block(widths[2], widths[3])
        self.ec2_2 = block(widths[3], widths[3])
        self.ec3 = block(widths[3], widths[4])
        self.ec3_2 = block(widths[4], widths[4])
        self.maxpool = nn.MaxPool3d(2)

        # Expanding path: residual blocks, then 2x transposed-conv upsampling.
        # The dc* blocks after each `up` take doubled channels because the
        # skip connection is concatenated in forward().
        self.dc3 = block(widths[4], widths[4])
        self.dc3_2 = block(widths[4], widths[4])
        self.up3 = self.decoder(widths[4], widths[3],
                                kernel_size=2, stride=2, bias=use_bias)
        self.dc2 = block(widths[4], widths[3])
        self.dc2_2 = block(widths[3], widths[3])
        self.up2 = self.decoder(widths[3], widths[2],
                                kernel_size=2, stride=2, bias=use_bias)
        self.dc1 = block(widths[3], widths[2])
        self.dc1_2 = block(widths[2], widths[2])
        self.up1 = self.decoder(widths[2], widths[1],
                                kernel_size=2, stride=2, bias=use_bias)
        self.dc0a = block(widths[2], widths[1])
        # Final 1x1x1 projection to class logits (no GN/ReLU).
        self.dc0b = self.decoder(widths[1], n_classes,
                                 kernel_size=1, stride=1,
                                 bias=use_bias, relu=False)

        # He initialization for all convolutions; GroupNorm starts as identity.
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def decoder(self,
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=0,
                output_padding=0,
                bias=True,
                relu=True):
        """Build an upsampling stage: ConvTranspose3d, optionally GN + ReLU."""
        deconv = nn.ConvTranspose3d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    stride=stride,
                                    padding=padding,
                                    output_padding=output_padding,
                                    bias=bias)
        stages = [deconv]
        if relu:
            stages.append(nn.GroupNorm(self.num_groups, out_channels))
            stages.append(nn.ReLU())
        return nn.Sequential(*stages)

    @staticmethod
    def _fit(up, skip):
        # Trilinear-resize `up` to the skip tensor's spatial size when a 2x
        # transposed conv could not exactly invert pooling of an odd dimension.
        if up.size()[2:] != skip.size()[2:]:
            up = F.interpolate(up,
                               skip.size()[2:],
                               mode='trilinear',
                               align_corners=False)
        return up

    def forward(self, x):
        """Return per-voxel class logits of shape (N, n_classes, D, H, W)."""
        # Encoder levels at full, 1/2, 1/4 and 1/8 resolution.
        e0 = self.ec0(x)
        e1 = self.ec1_2(self.ec1(self.maxpool(e0)))
        e2 = self.ec2_2(self.ec2(self.maxpool(e1)))
        e3 = self.ec3_2(self.ec3(self.maxpool(e2)))

        # Decoder: process, upsample, align to the skip, then concatenate.
        d3 = self.up3(self.dc3_2(self.dc3(e3)))
        d3 = torch.cat((self._fit(d3, e2), e2), 1)
        d2 = self.up2(self.dc2_2(self.dc2(d3)))
        d2 = torch.cat((self._fit(d2, e1), e1), 1)
        d1 = self.up1(self.dc1_2(self.dc1(d2)))
        d1 = torch.cat((self._fit(d1, e0), e0), 1)
        return self.dc0b(self.dc0a(d1))