|
a |
|
b/CaraNet/CaraNet.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Created on Wed Jul 21 14:58:14 2021 |
|
|
4 |
|
|
|
5 |
@author: angelou |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
import torch |
|
|
9 |
import torch.nn as nn |
|
|
10 |
import torch.nn.functional as F |
|
|
11 |
from pretrain.Res2Net_v1b import res2net50_v1b_26w_4s, res2net101_v1b_26w_4s |
|
|
12 |
import math |
|
|
13 |
import torchvision.models as models |
|
|
14 |
from lib.conv_layer import Conv, BNPReLU |
|
|
15 |
from lib.axial_atten import AA_kernel |
|
|
16 |
from lib.context_module import CFPModule |
|
|
17 |
from lib.partial_decoder import aggregation |
|
|
18 |
import os |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
|
|
|
22 |
|
|
|
23 |
class caranet(nn.Module):
    """CaraNet: Context Axial Reverse Attention Network.

    A Res2Net-101 backbone feeds its three deepest stages through
    receptive-field reduction convs into a partial decoder, producing a
    coarse segmentation map. Three cascaded refinement stages then combine
    CFP context features with axial attention and reverse attention to
    sharpen the prediction from coarse (1/32) to fine (1/8) resolution.

    Fix vs. original: the ``channel`` argument was accepted but ignored —
    every sub-module hard-coded 32. It is now threaded through all
    sub-modules; the default (32) reproduces the original behavior exactly.
    """

    def __init__(self, channel=32):
        """
        Args:
            channel (int): width of the reduced feature maps used by the
                decoder, CFP, reverse-attention and axial-attention modules.
                Default 32 (the paper setting).
        """
        super().__init__()

        # Kept on the instance so forward() can expand the 1-channel reverse
        # attention map to the working width.
        self.channel = channel

        # ---- Res2Net backbone (ImageNet-pretrained) ----
        self.resnet = res2net101_v1b_26w_4s(pretrained=True)

        # Receptive-field reduction: compress backbone stages 2-4 to `channel`.
        self.rfb2_1 = Conv(512, channel, 3, 1, padding=1, bn_acti=True)
        self.rfb3_1 = Conv(1024, channel, 3, 1, padding=1, bn_acti=True)
        self.rfb4_1 = Conv(2048, channel, 3, 1, padding=1, bn_acti=True)

        # Partial decoder: aggregates the three reduced maps into one coarse map.
        self.agg1 = aggregation(channel)

        # Context feature pyramid modules, dilation rate 8
        # (original note: "dilation rate 4, 62.8").
        self.CFP_1 = CFPModule(channel, d=8)
        self.CFP_2 = CFPModule(channel, d=8)
        self.CFP_3 = CFPModule(channel, d=8)

        # Reverse-attention refinement heads: channel -> channel -> 1 logit map.
        self.ra1_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra1_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra1_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        self.ra2_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra2_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra2_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        self.ra3_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra3_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra3_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        # Axial attention kernels, one per refinement stage.
        self.aa_kernel_1 = AA_kernel(channel, channel)
        self.aa_kernel_2 = AA_kernel(channel, channel)
        self.aa_kernel_3 = AA_kernel(channel, channel)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input image batch; the spatial comments below assume
               (bs, 3, 352, 352) as in the original script — TODO confirm
               other sizes divisible by 32 also work.

        Returns:
            Tuple of four 1-channel logit maps upsampled to input resolution,
            ordered finest-refinement first:
            (lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1).
        """
        # ---- backbone stem ----
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88

        # ----------- low-level features -------------
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44
        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11

        # Reduce each stage to `channel` feature maps.
        x2_rfb = self.rfb2_1(x2)        # 512  -> channel
        x3_rfb = self.rfb3_1(x3)        # 1024 -> channel
        x4_rfb = self.rfb4_1(x4)        # 2048 -> channel

        # Coarse map from the partial decoder (1, 44, 44) and its full-size output.
        decoder_1 = self.agg1(x4_rfb, x3_rfb, x2_rfb)
        lateral_map_1 = F.interpolate(decoder_1, scale_factor=8, mode='bilinear')

        # ------------------- atten-one (1/32 scale) -----------------------
        decoder_2 = F.interpolate(decoder_1, scale_factor=0.25, mode='bilinear')
        cfp_out_1 = self.CFP_3(x4_rfb)
        # Reverse attention: 1 - sigmoid(prediction) highlights missed regions.
        decoder_2_ra = -1 * (torch.sigmoid(decoder_2)) + 1
        aa_atten_3 = self.aa_kernel_3(cfp_out_1)
        aa_atten_3_o = decoder_2_ra.expand(-1, self.channel, -1, -1).mul(aa_atten_3)

        ra_3 = self.ra3_conv1(aa_atten_3_o)   # channel -> channel
        ra_3 = self.ra3_conv2(ra_3)           # channel -> channel
        ra_3 = self.ra3_conv3(ra_3)           # channel -> 1

        # Residual refinement of the coarse prediction.
        x_3 = ra_3 + decoder_2
        lateral_map_2 = F.interpolate(x_3, scale_factor=32, mode='bilinear')

        # ------------------- atten-two (1/16 scale) -----------------------
        decoder_3 = F.interpolate(x_3, scale_factor=2, mode='bilinear')
        cfp_out_2 = self.CFP_2(x3_rfb)
        decoder_3_ra = -1 * (torch.sigmoid(decoder_3)) + 1
        aa_atten_2 = self.aa_kernel_2(cfp_out_2)
        aa_atten_2_o = decoder_3_ra.expand(-1, self.channel, -1, -1).mul(aa_atten_2)

        ra_2 = self.ra2_conv1(aa_atten_2_o)   # channel -> channel
        ra_2 = self.ra2_conv2(ra_2)           # channel -> channel
        ra_2 = self.ra2_conv3(ra_2)           # channel -> 1

        x_2 = ra_2 + decoder_3
        lateral_map_3 = F.interpolate(x_2, scale_factor=16, mode='bilinear')

        # ------------------- atten-three (1/8 scale) ----------------------
        decoder_4 = F.interpolate(x_2, scale_factor=2, mode='bilinear')
        cfp_out_3 = self.CFP_1(x2_rfb)
        decoder_4_ra = -1 * (torch.sigmoid(decoder_4)) + 1
        aa_atten_1 = self.aa_kernel_1(cfp_out_3)
        aa_atten_1_o = decoder_4_ra.expand(-1, self.channel, -1, -1).mul(aa_atten_1)

        ra_1 = self.ra1_conv1(aa_atten_1_o)   # channel -> channel
        ra_1 = self.ra1_conv2(ra_1)           # channel -> channel
        ra_1 = self.ra1_conv3(ra_1)           # channel -> 1

        x_1 = ra_1 + decoder_4
        lateral_map_5 = F.interpolate(x_1, scale_factor=8, mode='bilinear')

        return lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1
|
|
131 |
|
|
|
132 |
|
|
|
133 |
|
|
|
134 |
if __name__ == '__main__':
    # Smoke test: push one 352x352 RGB image through the network.
    # Fix vs. original: the script unconditionally called .cuda(), which
    # raises on CPU-only machines — pick the device at runtime instead.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ras = caranet().to(device)
    input_tensor = torch.randn(1, 3, 352, 352, device=device)

    out = ras(input_tensor)