Diff of /model/MyModel.py [000000] .. [f77492]


b/model/MyModel.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import init
from features.feature import *  # expected to provide otsu()
from gcn.layers import GConv

class Self_Attn(nn.Module):
    """Self-attention layer."""
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out, attention
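
# Illustrative shape check for Self_Attn (a sketch, not part of the model; the
# sizes below are assumptions chosen only to show the expected tensor shapes):
#
#   attn = Self_Attn(in_dim=32, activation='relu')
#   feat = torch.randn(2, 32, 64, 64)                  # B x C x H x W
#   out, attention = attn(feat)
#   # out keeps the input shape (residual connection scaled by gamma)
#   # out.shape       == (2, 32, 64, 64)
#   # attention.shape == (2, 64 * 64, 64 * 64)         # N x N, with N = H * W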

def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
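
# Typical usage (sketch; the variable names below are illustrative and not
# taken from the training code in this repository):
#
#   net = Mymodel(img_ch=3, output_ch=1)
#   init_weights(net, init_type='kaiming')   # applied recursively via net.apply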

class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Recurrent convolution: x1 is initialised from x, then the shared conv
        # is re-applied t times to the sum of the input and the previous state.
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1

class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1
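
# Illustrative use of RRCNN_block (a sketch; the sizes are arbitrary, and this
# block is defined here but not referenced by Mymodel below):
#
#   block = RRCNN_block(ch_in=64, ch_out=128, t=2)
#   y = block(torch.randn(1, 64, 32, 32))   # -> 1 x 128 x 32 x 32
#   # The 1x1 convolution matches the channel count first, two recurrent blocks
#   # refine the features, and the result is added back residually.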

class single_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
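
# Attention gate sketch (illustrative only; the shapes are assumptions chosen
# to match how Mymodel calls this block, with g and x at the same resolution):
#
#   att = Attention_block(F_g=512, F_l=512, F_int=256)
#   g = torch.randn(1, 512, 32, 32)    # gating signal from the decoder
#   x = torch.randn(1, 512, 32, 32)    # skip connection from the encoder
#   gated = att(g=g, x=x)              # -> 1 x 512 x 32 x 32, x scaled by a
#                                      #    per-pixel coefficient in (0, 1)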

def dootsu(img):
    """Apply Otsu thresholding to every (H, W) channel map of a B x C x H x W tensor.

    The thresholding runs on CPU with NumPy, so gradients do not flow through
    this branch; the result is moved back to the GPU before returning.
    """
    se = img
    _, c, _, _ = img.size()
    # squeezelayer = nn.Conv2d(c, c // 16, kernel_size=1)
    # squeezelayer.cuda()
    # img = squeezelayer(img)
    img = img.cpu().detach()
    channel = list(img.size())[1]
    batch = list(img.size())[0]
    imgfolder = img.chunk(batch, dim=0)
    chw_output = []
    for index in range(batch):
        bchw = imgfolder[index]
        chw = bchw.squeeze()
        chwfolder = chw.chunk(channel, dim=0)
        hw_output = []
        for i in range(channel):
            hw = chwfolder[i].squeeze().numpy()   # single H x W map
            hw_otsu = otsu(hw)                    # otsu() comes from features.feature
            hw_otsu = torch.from_numpy(hw_otsu)
            hw_output.append(hw_otsu)
        chw_otsu = torch.stack(hw_output, dim=0)
        chw_output.append(chw_otsu)
    bchw_otsu = torch.stack(chw_output, dim=0).cuda()
    # result = torch.cat([se.float().cuda(), bchw_otsu.float().cuda()], dim=1)
    return bchw_otsu
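
# Usage note (illustrative; dootsu is only called from Self_Attn_OTSU below,
# and CUDA is assumed because the result is moved back to the GPU):
#
#   binary = dootsu(torch.rand(2, 32, 64, 64).cuda())   # -> 2 x 32 x 64 x 64 thresholded maps
#
# Because the per-channel thresholding runs in NumPy, the returned tensor is
# detached from the autograd graph: the attention and value convolutions get
# no gradient from this path; only gamma and the residual x term do.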

class Self_Attn_OTSU(nn.Module):
    """Self-attention layer followed by per-channel Otsu thresholding."""
    def __init__(self, in_dim, activation):
        super(Self_Attn_OTSU, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = dootsu(out)

        out = self.gamma * out + x
        return out, attention

def edge_conv2d(im):
    in_channel = list(im.size())[1]
    out_channel = in_channel
    # Define the convolution operation with nn.Conv2d
    conv_op = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
    # Define the edge-detection kernel. Dividing all values by 3 is said by some
    # to give nicer-looking maps, but it seems to make no real difference.
    # sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') / 3
    sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32')
    # Reshape the kernel to the layout expected by nn.Conv2d: (out_ch, in_ch, kH, kW)
    sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3))
    # Repeat across the input-channel dimension
    sobel_kernel = np.repeat(sobel_kernel, in_channel, axis=1)
    # Repeat across the output-channel dimension
    sobel_kernel = np.repeat(sobel_kernel, out_channel, axis=0)

    # Assign the fixed kernel as the convolution weights
    conv_op.weight.data = torch.from_numpy(sobel_kernel).cuda()
    # print(conv_op.weight.size())
    # print(im.size())
    # print(conv_op, '\n')

    edge_detect = conv_op(im)
    # print(torch.max(edge_detect))
    # Squeeze to an image-like tensor and detach it from the autograd graph
    edge_detect = edge_detect.squeeze().detach()
    return edge_detect
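
# Illustrative call (a sketch; the input size is arbitrary and CUDA is assumed,
# since the kernel is moved to the GPU above):
#
#   edges = edge_conv2d(torch.rand(2, 32, 64, 64).cuda())   # -> 2 x 32 x 64 x 64
#
# With a 32-channel input the weight tensor has shape (32, 32, 3, 3): every
# output channel sums the edge response over all input channels. The result is
# detached, so this branch contributes no gradients to the attention convolutions.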

class Self_Attn_EDGE(nn.Module):
    """Self-attention layer followed by fixed edge-detection filtering."""
    def __init__(self, in_dim, activation):
        super(Self_Attn_EDGE, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = edge_conv2d(out)

        out = self.gamma * out + x
        return out, attention
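
# Note (an observation about the code above, not a behavioural change): because
# edge_conv2d() squeezes and detaches its output, the edge branch acts as a
# fixed feature extractor; gradients reach only gamma and, through the residual
# x, the layers below. With a batch size of 1 the squeeze drops the batch axis,
# and the final addition then relies on broadcasting against x.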

class Self_Attn_GABOR(nn.Module):
    """Self-attention layer followed by a Gabor orientation convolution."""
    def __init__(self, in_dim, activation):
        super(Self_Attn_GABOR, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gaborlayer = GConv(in_dim, out_channels=in_dim // 4, kernel_size=1, stride=1, M=4, nScale=1, bias=False, expand=True)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # print(out.size())

        out = self.gaborlayer(out)
        # print(out.size())

        out = self.gamma * out + x
        return out, attention
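
# Note on the GConv configuration (inferred from how the layer is constructed
# above; the exact output layout depends on the GConv implementation in
# gcn.layers): with out_channels=in_dim//4, M=4 orientations and expand=True,
# the Gabor layer is presumably expected to produce in_dim channels again
# (in_dim//4 * 4), which is what the residual addition with x requires.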

class Mymodel(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(Mymodel, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        # # gabor
        # self.gaborlayer = GConv(in_channels=64, out_channels=16, kernel_size=1, stride=1, M=4, nScale=1, bias=False, expand=True)

        self.pred1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
        self.pred2 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
        self.pred3 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)

        # self-attention
        self.pred_1 = single_conv(ch_in=64, ch_out=32)
        self.pred_2 = single_conv(ch_in=64, ch_out=32)
        self.pred_3 = single_conv(ch_in=64, ch_out=32)
        self.conv_Atten1 = Self_Attn_OTSU(32, 'relu')
        self.conv_Atten2 = Self_Attn_GABOR(32, 'relu')
        # self.conv_Atten3 = Self_Attn(32, 'relu')
        self.conv_Atten3 = Self_Attn_GABOR(32, 'relu')

        # fusion module
        self.conv_fusion1 = conv_block(ch_in=160, ch_out=64)
        self.conv_fusion2 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)   # v1

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        # d2_gabor = self.gaborlayer(d2)

        pred_1 = self.pred_1(d2)
        pred_2 = self.pred_2(d2)
        pred_3 = self.pred_3(d2)
        pred1 = self.pred1(pred_1)
        pred2 = self.pred2(pred_2)
        pred3 = self.pred3(pred_3)

        # self-attention
        attention_higher, _ = self.conv_Atten1(pred_1)
        attention_lower, _ = self.conv_Atten2(pred_2)
        attention_all, _ = self.conv_Atten3(pred_3)

        # fusion module
        y = torch.cat((attention_higher, attention_all, attention_lower, d2), dim=1)
        y = self.conv_fusion1(y)
        pred = self.conv_fusion2(y)

        return [pred1, pred2, pred3, pred]
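
# End-to-end sketch of the model interface (illustrative; the 256x256 input
# size is an assumption, any size divisible by 16 should work, and CUDA is
# assumed because the OTSU branch moves data to the GPU):
#
#   model = Mymodel(img_ch=3, output_ch=1).cuda()
#   init_weights(model, init_type='kaiming')
#   x = torch.randn(2, 3, 256, 256).cuda()
#   pred1, pred2, pred3, pred = model(x)
#   # each output is 2 x 1 x 256 x 256: pred1/pred2/pred3 are auxiliary
#   # predictions from the three 32-channel branches, and pred is the fused
#   # prediction produced by conv_fusion2.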