a b/model.py
1
# -*- coding: utf-8 -*-
2
from __future__ import print_function, division
3
4
import torch
5
import torch.nn as nn
6
import numpy as np
7
from torch.utils.checkpoint import checkpoint
8
9
10
class UnetBlock_Encode(nn.Module):
    """Convolution block whose output is a conv branch plus a gated shortcut.

    The input first goes through ``conv1`` (a convolution along the last
    axis only).  The result ``h`` then feeds two branches:

    * ``conv2_1`` -- a (3, 3, 1) convolution, batch norm, ReLU and dropout;
    * ``conv2_2`` -- average pooling, a 1x1x1 convolution, batch norm and
      trilinear upsampling back by a factor of 2; its sigmoid acts as a
      gate on ``h``.

    The block returns ``conv2_1(h) + sigmoid(conv2_2(h)) * h``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channel: number of channels produced by the block.

    Note:
        The pooling (kernel 4, stride 2, padding 1) followed by x2
        upsampling only restores the original size when every spatial
        dimension is even; otherwise the gate and ``h`` cannot be
        multiplied.
    """

    def __init__(self, in_channels, out_channel):
        super(UnetBlock_Encode, self).__init__()

        self.in_chns = in_channels
        self.out_chns = out_channel

        # Convolution along the last axis; the padding keeps the size.
        self.conv1 = nn.Sequential(
            nn.Conv3d(self.in_chns, self.out_chns,
                      kernel_size=(1, 1, 3), padding=(0, 0, 1)),
            nn.BatchNorm3d(self.out_chns),
            nn.ReLU(inplace=True),
        )

        # Convolution over the first two spatial axes, size preserved.
        self.conv2_1 = nn.Sequential(
            nn.Conv3d(self.out_chns, self.out_chns,
                      kernel_size=(3, 3, 1), padding=(1, 1, 0), groups=1),
            nn.BatchNorm3d(self.out_chns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
        )

        # Gate branch: pool to half resolution, mix channels, upsample back.
        self.conv2_2 = nn.Sequential(
            nn.AvgPool3d(kernel_size=4, stride=2, padding=1),
            nn.Conv3d(self.out_chns, self.out_chns,
                      kernel_size=1, padding=0),
            nn.BatchNorm3d(self.out_chns),
            nn.Upsample(scale_factor=2, mode='trilinear',
                        align_corners=False),
        )

    def forward(self, x):
        """Return ``conv2_1(h) + sigmoid(conv2_2(h)) * h`` with ``h = conv1(x)``."""
        feat = self.conv1(x)
        branch = self.conv2_1(feat)
        gate = torch.sigmoid(self.conv2_2(feat))
        return branch + gate * feat
48
49
class UnetBlock_Encode_4(nn.Module):
    """Gated convolution block without pooling in the gate branch.

    The input first goes through ``conv1`` (a convolution along the last
    axis only).  The result ``h`` then feeds two branches:

    * ``conv2_1`` -- a depthwise (``groups == out_channel``) (3, 3, 1)
      convolution, batch norm, ReLU and dropout;
    * ``conv2_2`` -- a 1x1x1 convolution and batch norm; its sigmoid acts
      as a gate on ``h``.

    The block returns ``conv2_1(h) + sigmoid(conv2_2(h)) * h``.  All
    layers preserve the spatial size.

    Args:
        in_channels: number of channels of the input tensor.
        out_channel: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channel):
        super(UnetBlock_Encode_4, self).__init__()

        self.in_chns = in_channels
        self.out_chns = out_channel

        # Convolution along the last axis; the padding keeps the size.
        self.conv1 = nn.Sequential(
            nn.Conv3d(self.in_chns, self.out_chns,
                      kernel_size=(1, 1, 3), padding=(0, 0, 1)),
            nn.BatchNorm3d(self.out_chns),
            nn.ReLU(inplace=True),
        )

        # Depthwise convolution: one filter group per channel.
        self.conv2_1 = nn.Sequential(
            nn.Conv3d(self.out_chns, self.out_chns,
                      kernel_size=(3, 3, 1), padding=(1, 1, 0),
                      groups=self.out_chns),
            nn.BatchNorm3d(self.out_chns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
        )

        # Gate branch: pointwise channel mixing only.
        self.conv2_2 = nn.Sequential(
            nn.Conv3d(self.out_chns, self.out_chns,
                      kernel_size=1, padding=0),
            nn.BatchNorm3d(self.out_chns),
        )

    def forward(self, x):
        """Return ``conv2_1(h) + sigmoid(conv2_2(h)) * h`` with ``h = conv1(x)``."""
        feat = self.conv1(x)
        branch = self.conv2_1(feat)
        gate = torch.sigmoid(self.conv2_2(feat))
        return branch + gate * feat
84
85
86
87
class UnetBlock_Down(nn.Module):
    """Downsampling step: 2x2x2 average pooling.

    Args:
        in_channels: accepted for signature parity with the other blocks;
            not used.
        out_channel: accepted for signature parity with the other blocks;
            not used.
    """

    def __init__(self, in_channels, out_channel):
        super(UnetBlock_Down, self).__init__()
        self.avg_pool = nn.AvgPool3d(kernel_size=2)

    def forward(self, x):
        """Return ``x`` average-pooled with a kernel (and stride) of 2."""
        return self.avg_pool(x)
95
96
class UnetBlock_Up(nn.Module):
    """Upsampling step: 1x1x1 convolution followed by x2 trilinear upsampling.

    Args:
        in_channels: number of channels of the input tensor.
        out_channel: number of channels after the 1x1x1 convolution.
    """

    def __init__(self, in_channels, out_channel):
        super(UnetBlock_Up, self).__init__()

        projection = nn.Sequential(
            nn.Conv3d(in_channels, out_channel, kernel_size=1,
                      padding=0, groups=1),
            nn.BatchNorm3d(out_channel),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
        )
        # The same module is registered under both names (``conv`` first,
        # then ``conv1``), so ``state_dict()`` exposes ``conv.*`` and
        # ``conv1.*`` keys that point at the same tensors.  Only ``conv``
        # is used in ``forward``.
        self.conv = projection
        self.conv1 = projection

        self.up = nn.Upsample(scale_factor=2, mode='trilinear',
                              align_corners=False)

    def forward(self, x):
        """Project the channels of ``x`` and double every spatial dimension."""
        return self.up(self.conv(x))
113
114
class UNet_Seg(nn.Module):
    """Four-level 3D U-Net built from the gated blocks in this file.

    Feature widths per level are 32, 64, 128 and 256.  The contracting
    path applies ``block1``..``block4`` with 2x average pooling between
    them; the expanding path upsamples, concatenates the matching
    contracting-path features along the channel axis and applies
    ``block5``..``block7``.  A final 3x3x3 convolution maps to
    ``n_classes`` channels and a sigmoid is applied to the result.

    Args:
        C_in: number of input channels.
        n_classes: number of output channels.

    Note:
        The three 2x poolings, the x2 upsamplings and the even-size
        requirement of ``UnetBlock_Encode`` mean each spatial dimension of
        the input needs to be a multiple of 8 for the skip concatenations
        to line up.
    """

    def __init__(self, C_in=1, n_classes=1):
        super(UNet_Seg, self).__init__()
        self.in_chns = C_in
        self.n_class = n_classes
        inchn = 32
        self.ft_chns = [inchn, inchn * 2, inchn * 4, inchn * 8]
        self.resolution_level = len(self.ft_chns)
        c0, c1, c2, c3 = self.ft_chns

        # Contracting path.
        self.block1 = UnetBlock_Encode(self.in_chns, c0)
        self.block2 = UnetBlock_Encode(c0, c1)
        self.block3 = UnetBlock_Encode(c1, c2)
        self.block4 = UnetBlock_Encode_4(c2, c3)

        # Expanding path; inputs are skip + upsampled features, hence 2x.
        self.block5 = UnetBlock_Encode(2 * c2, c2)
        self.block6 = UnetBlock_Encode(2 * c1, c1)
        self.block7 = UnetBlock_Encode(2 * c0, c0)

        self.down1 = UnetBlock_Down(c0, c0)
        self.down2 = UnetBlock_Down(c1, c1)
        self.down3 = UnetBlock_Down(c2, c2)

        self.up1 = UnetBlock_Up(c3, c2)
        self.up2 = UnetBlock_Up(c2, c1)
        self.up3 = UnetBlock_Up(c1, c0)

        self.conv = nn.Conv3d(c0, self.n_class, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the sigmoid of the network output for input ``x``."""
        # Contracting path.
        f1 = self.block1(x)
        f2 = self.block2(self.down1(f1))
        f3 = self.block3(self.down2(f2))
        f4 = self.block4(self.down3(f3))

        # Expanding path with skip connections on the channel axis.
        f5 = self.block5(torch.cat((f3, self.up1(f4)), dim=1))
        f6 = self.block6(torch.cat((f2, self.up2(f5)), dim=1))
        f7 = self.block7(torch.cat((f1, self.up3(f6)), dim=1))

        return torch.sigmoid(self.conv(f7))
179
180
class LCOVNet(nn.Module):
    """Wrapper that runs ``UNet_Seg`` under activation checkpointing.

    Args:
        input_channels: number of channels of the input tensor.
        n_classes: number of output channels of the segmentation network.
    """

    def __init__(self, input_channels, n_classes):
        super(LCOVNet, self).__init__()
        self.seg_network = UNet_Seg(input_channels, n_classes)

    def seg(self, x):
        """Run the segmentation network directly (no checkpointing).

        Returns:
            The output of ``self.seg_network(x)``.
        """
        # Fix: the original computed the output but never returned it.
        return self.seg_network(x)

    def forward(self, x):
        """Run the segmentation network through ``torch.utils.checkpoint``.

        Returns:
            The output of ``self.seg_network`` for ``x``.
        """
        # Adding a zero tensor that requires grad leaves the values of ``x``
        # unchanged but makes the checkpoint input part of the autograd
        # graph.  Presumably this is here because reentrant checkpointing
        # needs at least one input with requires_grad=True for gradients
        # to reach the wrapped module's parameters.
        x = x + torch.zeros_like(x, requires_grad=True)
        return checkpoint(self.seg_network, x)