Diff of /CaraNet/CaraNet.py [000000] .. [6f3ba0]

Switch to unified view

a b/CaraNet/CaraNet.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Wed Jul 21 14:58:14 2021
4
5
@author: angelou
6
"""
7
8
import torch
9
import torch.nn as nn
10
import torch.nn.functional as F
11
from pretrain.Res2Net_v1b import res2net50_v1b_26w_4s, res2net101_v1b_26w_4s
12
import math
13
import torchvision.models as models
14
from lib.conv_layer import Conv, BNPReLU
15
from lib.axial_atten import AA_kernel
16
from lib.context_module import CFPModule
17
from lib.partial_decoder import aggregation
18
import os
19
 
20
    
21
22
    
23
class caranet(nn.Module):
24
    def __init__(self, channel=32):
25
        super().__init__()
26
        
27
         # ---- ResNet Backbone ----
28
        self.resnet = res2net101_v1b_26w_4s(pretrained=True)
29
30
        # Receptive Field Block
31
        self.rfb2_1 = Conv(512, 32,3,1,padding=1,bn_acti=True)
32
        self.rfb3_1 = Conv(1024, 32,3,1,padding=1,bn_acti=True)
33
        self.rfb4_1 = Conv(2048, 32,3,1,padding=1,bn_acti=True)
34
35
        # Partial Decoder
36
        self.agg1 = aggregation(channel)
37
        
38
        self.CFP_1 = CFPModule(32, d = 8)
39
        self.CFP_2 = CFPModule(32, d = 8)
40
        self.CFP_3 = CFPModule(32, d = 8)
41
        ###### dilation rate 4, 62.8
42
43
44
45
        self.ra1_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True)
46
        self.ra1_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True)
47
        self.ra1_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
48
        
49
        self.ra2_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True)
50
        self.ra2_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True)
51
        self.ra2_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
52
        
53
        self.ra3_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True)
54
        self.ra3_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True)
55
        self.ra3_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
56
        
57
        self.aa_kernel_1 = AA_kernel(32,32)
58
        self.aa_kernel_2 = AA_kernel(32,32)
59
        self.aa_kernel_3 = AA_kernel(32,32)
60
        
61
62
63
64
    def forward(self, x):
65
        
66
        x = self.resnet.conv1(x)
67
        x = self.resnet.bn1(x)
68
        x = self.resnet.relu(x)
69
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88
70
        
71
        # ----------- low-level features -------------
72
        
73
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
74
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44
75
        
76
        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
77
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11
78
        
79
        x2_rfb = self.rfb2_1(x2) # 512 - 32
80
        x3_rfb = self.rfb3_1(x3) # 1024 - 32
81
        x4_rfb = self.rfb4_1(x4) # 2048 - 32
82
        
83
        decoder_1 = self.agg1(x4_rfb, x3_rfb, x2_rfb) # 1,44,44
84
        lateral_map_1 = F.interpolate(decoder_1, scale_factor=8, mode='bilinear')
85
        
86
        # ------------------- atten-one -----------------------
87
        decoder_2 = F.interpolate(decoder_1, scale_factor=0.25, mode='bilinear')
88
        cfp_out_1 = self.CFP_3(x4_rfb) # 32 - 32
89
        decoder_2_ra = -1*(torch.sigmoid(decoder_2)) + 1
90
        aa_atten_3 = self.aa_kernel_3(cfp_out_1)
91
        aa_atten_3_o = decoder_2_ra.expand(-1, 32, -1, -1).mul(aa_atten_3)
92
        
93
        ra_3 = self.ra3_conv1(aa_atten_3_o) # 32 - 32
94
        ra_3 = self.ra3_conv2(ra_3) # 32 - 32
95
        ra_3 = self.ra3_conv3(ra_3) # 32 - 1
96
        
97
        x_3 = ra_3 + decoder_2
98
        lateral_map_2 = F.interpolate(x_3,scale_factor=32,mode='bilinear')
99
        
100
        # ------------------- atten-two -----------------------      
101
        decoder_3 = F.interpolate(x_3, scale_factor=2, mode='bilinear')
102
        cfp_out_2 = self.CFP_2(x3_rfb) # 32 - 32
103
        decoder_3_ra = -1*(torch.sigmoid(decoder_3)) + 1
104
        aa_atten_2 = self.aa_kernel_2(cfp_out_2)
105
        aa_atten_2_o = decoder_3_ra.expand(-1, 32, -1, -1).mul(aa_atten_2)
106
        
107
        ra_2 = self.ra2_conv1(aa_atten_2_o) # 32 - 32
108
        ra_2 = self.ra2_conv2(ra_2) # 32 - 32
109
        ra_2 = self.ra2_conv3(ra_2) # 32 - 1
110
        
111
        x_2 = ra_2 + decoder_3
112
        lateral_map_3 = F.interpolate(x_2,scale_factor=16,mode='bilinear')        
113
        
114
        # ------------------- atten-three -----------------------
115
        decoder_4 = F.interpolate(x_2, scale_factor=2, mode='bilinear')
116
        cfp_out_3 = self.CFP_1(x2_rfb) # 32 - 32
117
        decoder_4_ra = -1*(torch.sigmoid(decoder_4)) + 1
118
        aa_atten_1 = self.aa_kernel_1(cfp_out_3)
119
        aa_atten_1_o = decoder_4_ra.expand(-1, 32, -1, -1).mul(aa_atten_1)
120
        
121
        ra_1 = self.ra1_conv1(aa_atten_1_o) # 32 - 32
122
        ra_1 = self.ra1_conv2(ra_1) # 32 - 32
123
        ra_1 = self.ra1_conv3(ra_1) # 32 - 1
124
        
125
        x_1 = ra_1 + decoder_4
126
        lateral_map_5 = F.interpolate(x_1,scale_factor=8,mode='bilinear') 
127
        
128
129
130
        return lateral_map_5,lateral_map_3,lateral_map_2,lateral_map_1
131
    
132
133
134
if __name__ == '__main__':
135
    ras = caranet().cuda()
136
    input_tensor = torch.randn(1, 3, 352, 352).cuda()
137
138
    out = ras(input_tensor)