Diff of /model/Models.py [000000] .. [f77492]

Switch to unified view

a b/model/Models.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
from torch.nn import init
5
from ._utils import *
6
from math import sqrt
7
8
def init_weights(net, init_type='normal', gain=0.02):
9
    def init_func(m):
10
        classname = m.__class__.__name__
11
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
12
            if init_type == 'normal':
13
                init.normal_(m.weight.data, 0.0, gain)
14
            elif init_type == 'xavier':
15
                init.xavier_normal_(m.weight.data, gain=gain)
16
            elif init_type == 'kaiming':
17
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
18
            elif init_type == 'orthogonal':
19
                init.orthogonal_(m.weight.data, gain=gain)
20
            else:
21
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
22
            if hasattr(m, 'bias') and m.bias is not None:
23
                init.constant_(m.bias.data, 0.0)
24
        elif classname.find('BatchNorm2d') != -1:
25
            init.normal_(m.weight.data, 1.0, gain)
26
            init.constant_(m.bias.data, 0.0)
27
28
    print('initialize network with %s' % init_type)
29
    net.apply(init_func)
30
31
class DoubleConv(nn.Module):
32
    """
33
    Double Conv for U-Net
34
    """
35
    def __init__(self, in_ch, out_ch, k_1=3, k_2=3):
36
        super(DoubleConv, self).__init__()
37
        padding_1 = cal_same_padding(k_1)
38
        padding_2 = cal_same_padding(k_2)
39
        self.conv = nn.Sequential(
40
            nn.Conv2d(in_ch, out_ch, k_1, padding=padding_1),  # in_ch、out_ch是通道数
41
            nn.BatchNorm2d(out_ch),
42
            nn.ReLU(inplace=True),
43
            # Mish(),
44
            nn.Conv2d(out_ch, out_ch, k_2, padding=padding_2),
45
            nn.BatchNorm2d(out_ch),
46
            nn.ReLU(inplace=True)
47
        )
48
49
        for m in self.modules():
50
            if isinstance(m, nn.Conv2d):
51
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
52
                m.weight.data.normal_(0, sqrt(2. / n))
53
                if m.bias is not None:
54
                    m.bias.data.zero_()
55
            elif isinstance(m, nn.BatchNorm2d):
56
                m.weight.data.normal_(1.0, 0.02)
57
                m.bias.data.fill_(0)
58
59
    def forward(self, x):
60
        return self.conv(x)
61
62
class conv_block(nn.Module):
63
    def __init__(self,ch_in,ch_out):
64
        super(conv_block,self).__init__()
65
        self.conv = nn.Sequential(
66
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
67
            nn.BatchNorm2d(ch_out),
68
            nn.ReLU(inplace=True),
69
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
70
            nn.BatchNorm2d(ch_out),
71
            nn.ReLU(inplace=True)
72
        )
73
74
75
    def forward(self,x):
76
        x = self.conv(x)
77
        return x
78
79
class up_conv(nn.Module):
80
    def __init__(self,ch_in,ch_out):
81
        super(up_conv,self).__init__()
82
        self.up = nn.Sequential(
83
            nn.Upsample(scale_factor=2),
84
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
85
            nn.BatchNorm2d(ch_out),
86
            nn.ReLU(inplace=True)
87
        )
88
89
    def forward(self,x):
90
        x = self.up(x)
91
        return x
92
93
class Recurrent_block(nn.Module):
94
    def __init__(self,ch_out,t=2):
95
        super(Recurrent_block,self).__init__()
96
        self.t = t
97
        self.ch_out = ch_out
98
        self.conv = nn.Sequential(
99
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
100
            nn.BatchNorm2d(ch_out),
101
            nn.ReLU(inplace=True)
102
        )
103
104
    def forward(self,x):
105
        for i in range(self.t):
106
107
            if i==0:
108
                x1 = self.conv(x)
109
            
110
            x1 = self.conv(x+x1)
111
        return x1
112
        
113
class RRCNN_block(nn.Module):
114
    def __init__(self,ch_in,ch_out,t=2):
115
        super(RRCNN_block,self).__init__()
116
        self.RCNN = nn.Sequential(
117
            Recurrent_block(ch_out,t=t),
118
            Recurrent_block(ch_out,t=t)
119
        )
120
        self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
121
122
    def forward(self,x):
123
        x = self.Conv_1x1(x)
124
        x1 = self.RCNN(x)
125
        return x+x1
126
127
class single_conv(nn.Module):
128
    def __init__(self,ch_in,ch_out):
129
        super(single_conv,self).__init__()
130
        self.conv = nn.Sequential(
131
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
132
            nn.BatchNorm2d(ch_out),
133
            nn.ReLU(inplace=True)
134
        )
135
136
    def forward(self,x):
137
        x = self.conv(x)
138
        return x
139
140
class Attention_block(nn.Module):
141
    def __init__(self,F_g,F_l,F_int):
142
        super(Attention_block,self).__init__()
143
        self.W_g = nn.Sequential(
144
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
145
            nn.BatchNorm2d(F_int)
146
            )
147
        
148
        self.W_x = nn.Sequential(
149
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
150
            nn.BatchNorm2d(F_int)
151
        )
152
153
        self.psi = nn.Sequential(
154
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
155
            nn.BatchNorm2d(1),
156
            nn.Sigmoid()
157
        )
158
        
159
        self.relu = nn.ReLU(inplace=True)
160
        
161
    def forward(self,g,x):
162
        g1 = self.W_g(g)
163
        x1 = self.W_x(x)
164
        psi = self.relu(g1+x1)
165
        psi = self.psi(psi)
166
167
        return x*psi
168
169
class U_Net(nn.Module):
170
    def __init__(self, img_ch=3, out_dim=1):
171
172
        super(U_Net, self).__init__()
173
        self.conv1 = DoubleConv(img_ch, 64)
174
        self.pool1 = nn.MaxPool2d(2)
175
        self.conv2 = DoubleConv(64, 128)
176
        self.pool2 = nn.MaxPool2d(2)
177
        self.conv3 = DoubleConv(128, 256)
178
        self.pool3 = nn.MaxPool2d(2)
179
        self.conv4 = DoubleConv(256, 512)
180
        self.pool4 = nn.MaxPool2d(2)
181
        self.conv5 = DoubleConv(512, 1024)
182
183
184
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
185
        self.conv6 = DoubleConv(1024, 512)
186
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
187
        self.conv7 = DoubleConv(512, 256)
188
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
189
        self.conv8 = DoubleConv(256, 128)
190
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
191
        self.conv9 = DoubleConv(128, 64)
192
        self.conv10 = nn.Conv2d(64, out_dim, 1)
193
194
195
        for m in self.modules():
196
            if isinstance(m, nn.Conv2d):
197
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
198
                m.weight.data.normal_(0, sqrt(2. / n))
199
                if m.bias is not None:
200
                    m.bias.data.zero_()
201
            elif isinstance(m, nn.BatchNorm2d):
202
                m.weight.data.normal_(1.0, 0.02)
203
                m.bias.data.fill_(0)
204
205
    def forward(self, inputs):
206
        c1 = self.conv1(inputs)
207
        p1 = self.pool1(c1)
208
        c2 = self.conv2(p1)
209
        p2 = self.pool2(c2)
210
        c3 = self.conv3(p2)
211
        p3 = self.pool3(c3)
212
        c4 = self.conv4(p3)
213
        p4 = self.pool4(c4)
214
        c5 = self.conv5(p4)
215
216
        up_6 = self.up6(c5)
217
        merge6 = torch.cat([up_6, c4], dim=1)
218
        c6 = self.conv6(merge6)  
219
        up_7 = self.up7(c6)
220
        merge7 = torch.cat([up_7, c3], dim=1)
221
        c7 = self.conv7(merge7)
222
        up_8 = self.up8(c7)
223
        merge8 = torch.cat([up_8, c2], dim=1)  # 256 *48
224
        c8 = self.conv8(merge8)
225
        up_9 = self.up9(c8)
226
        merge9 = torch.cat([up_9, c1], dim=1)
227
        c9 = self.conv9(merge9)       
228
        c10 = self.conv10(c9)
229
230
231
        return c10
232
233
234
class R2U_Net(nn.Module):
235
    def __init__(self,img_ch=3,output_ch=1,t=2):
236
        super(R2U_Net,self).__init__()
237
        
238
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
239
        self.Upsample = nn.Upsample(scale_factor=2)
240
241
        self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
242
243
        self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
244
        
245
        self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
246
        
247
        self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
248
        
249
        self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
250
        
251
252
        self.Up5 = up_conv(ch_in=1024,ch_out=512)
253
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
254
        
255
        self.Up4 = up_conv(ch_in=512,ch_out=256)
256
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
257
        
258
        self.Up3 = up_conv(ch_in=256,ch_out=128)
259
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
260
        
261
        self.Up2 = up_conv(ch_in=128,ch_out=64)
262
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
263
264
        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
265
266
267
    def forward(self,x):
268
        # encoding path
269
        x1 = self.RRCNN1(x)
270
271
        x2 = self.Maxpool(x1)
272
        x2 = self.RRCNN2(x2)
273
        
274
        x3 = self.Maxpool(x2)
275
        x3 = self.RRCNN3(x3)
276
277
        x4 = self.Maxpool(x3)
278
        x4 = self.RRCNN4(x4)
279
280
        x5 = self.Maxpool(x4)
281
        x5 = self.RRCNN5(x5)
282
283
        # decoding + concat path
284
        d5 = self.Up5(x5)
285
        d5 = torch.cat((x4,d5),dim=1)
286
        d5 = self.Up_RRCNN5(d5)
287
        
288
        d4 = self.Up4(d5)
289
        d4 = torch.cat((x3,d4),dim=1)
290
        d4 = self.Up_RRCNN4(d4)
291
292
        d3 = self.Up3(d4)
293
        d3 = torch.cat((x2,d3),dim=1)
294
        d3 = self.Up_RRCNN3(d3)
295
296
        d2 = self.Up2(d3)
297
        d2 = torch.cat((x1,d2),dim=1)
298
        d2 = self.Up_RRCNN2(d2)
299
300
        d1 = self.Conv_1x1(d2)
301
302
        return d1
303
304
305
class AttU_Net(nn.Module):
306
    def __init__(self,img_ch=3,output_ch=1):
307
        super(AttU_Net,self).__init__()
308
        
309
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
310
311
        self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
312
        self.Conv2 = conv_block(ch_in=64,ch_out=128)
313
        self.Conv3 = conv_block(ch_in=128,ch_out=256)
314
        self.Conv4 = conv_block(ch_in=256,ch_out=512)
315
        self.Conv5 = conv_block(ch_in=512,ch_out=1024)
316
317
        self.Up5 = up_conv(ch_in=1024,ch_out=512)
318
        self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
319
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
320
321
        self.Up4 = up_conv(ch_in=512,ch_out=256)
322
        self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
323
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
324
        
325
        self.Up3 = up_conv(ch_in=256,ch_out=128)
326
        self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
327
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
328
        
329
        self.Up2 = up_conv(ch_in=128,ch_out=64)
330
        self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
331
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
332
333
        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
334
335
336
    def forward(self,x):
337
        # encoding path
338
        x1 = self.Conv1(x)
339
340
        x2 = self.Maxpool(x1)
341
        x2 = self.Conv2(x2)
342
        
343
        x3 = self.Maxpool(x2)
344
        x3 = self.Conv3(x3)
345
346
        x4 = self.Maxpool(x3)
347
        x4 = self.Conv4(x4)
348
349
        x5 = self.Maxpool(x4)
350
        x5 = self.Conv5(x5)
351
352
        # decoding + concat path
353
        d5 = self.Up5(x5)
354
        x4 = self.Att5(g=d5,x=x4)
355
        d5 = torch.cat((x4,d5),dim=1)        
356
        d5 = self.Up_conv5(d5)
357
        
358
        d4 = self.Up4(d5)
359
        x3 = self.Att4(g=d4,x=x3)
360
        d4 = torch.cat((x3,d4),dim=1)
361
        d4 = self.Up_conv4(d4)
362
363
        d3 = self.Up3(d4)
364
        x2 = self.Att3(g=d3,x=x2)
365
        d3 = torch.cat((x2,d3),dim=1)
366
        d3 = self.Up_conv3(d3)
367
368
        d2 = self.Up2(d3)
369
        x1 = self.Att2(g=d2,x=x1)
370
        d2 = torch.cat((x1,d2),dim=1)
371
        d2 = self.Up_conv2(d2)
372
373
        d1 = self.Conv_1x1(d2)
374
375
        return d1
376
377
378
class R2AttU_Net(nn.Module):
379
    def __init__(self,img_ch=3,output_ch=1,t=2):
380
        super(R2AttU_Net,self).__init__()
381
        
382
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
383
        self.Upsample = nn.Upsample(scale_factor=2)
384
385
        self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
386
387
        self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
388
        
389
        self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
390
        
391
        self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
392
        
393
        self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
394
        
395
396
        self.Up5 = up_conv(ch_in=1024,ch_out=512)
397
        self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
398
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
399
        
400
        self.Up4 = up_conv(ch_in=512,ch_out=256)
401
        self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
402
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
403
        
404
        self.Up3 = up_conv(ch_in=256,ch_out=128)
405
        self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
406
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
407
        
408
        self.Up2 = up_conv(ch_in=128,ch_out=64)
409
        self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
410
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
411
412
        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
413
414
415
    def forward(self,x):
416
        # encoding path
417
        x1 = self.RRCNN1(x)
418
419
        x2 = self.Maxpool(x1)
420
        x2 = self.RRCNN2(x2)
421
        
422
        x3 = self.Maxpool(x2)
423
        x3 = self.RRCNN3(x3)
424
425
        x4 = self.Maxpool(x3)
426
        x4 = self.RRCNN4(x4)
427
428
        x5 = self.Maxpool(x4)
429
        x5 = self.RRCNN5(x5)
430
431
        # decoding + concat path
432
        d5 = self.Up5(x5)
433
        x4 = self.Att5(g=d5,x=x4)
434
        d5 = torch.cat((x4,d5),dim=1)
435
        d5 = self.Up_RRCNN5(d5)
436
        
437
        d4 = self.Up4(d5)
438
        x3 = self.Att4(g=d4,x=x3)
439
        d4 = torch.cat((x3,d4),dim=1)
440
        d4 = self.Up_RRCNN4(d4)
441
442
        d3 = self.Up3(d4)
443
        x2 = self.Att3(g=d3,x=x2)
444
        d3 = torch.cat((x2,d3),dim=1)
445
        d3 = self.Up_RRCNN3(d3)
446
447
        d2 = self.Up2(d3)
448
        x1 = self.Att2(g=d2,x=x1)
449
        d2 = torch.cat((x1,d2),dim=1)
450
        d2 = self.Up_RRCNN2(d2)
451
452
        d1 = self.Conv_1x1(d2)
453
454
        return d1