Diff of /CaraNet/CaraNet.py [000000] .. [6f3ba0]

--- a
+++ b/CaraNet/CaraNet.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Jul 21 14:58:14 2021
+
+@author: angelou
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pretrain.Res2Net_v1b import res2net101_v1b_26w_4s
+from lib.conv_layer import Conv, BNPReLU
+from lib.axial_atten import AA_kernel
+from lib.context_module import CFPModule
+from lib.partial_decoder import aggregation
+
+
+class caranet(nn.Module):
+    def __init__(self, channel=32):
+        super().__init__()
+
+        # ---- Res2Net backbone ----
+        self.resnet = res2net101_v1b_26w_4s(pretrained=True)
+
+        # Receptive Field Block
+        self.rfb2_1 = Conv(512, 32, 3, 1, padding=1, bn_acti=True)
+        self.rfb3_1 = Conv(1024, 32, 3, 1, padding=1, bn_acti=True)
+        self.rfb4_1 = Conv(2048, 32, 3, 1, padding=1, bn_acti=True)
+
+        # Partial Decoder
+        self.agg1 = aggregation(channel)
+
+        # ---- CFP context modules ----
+        self.CFP_1 = CFPModule(32, d=8)
+        self.CFP_2 = CFPModule(32, d=8)
+        self.CFP_3 = CFPModule(32, d=8)
+        # dilation rate 4: 62.8
+
+        # ---- reverse-attention refinement convs (one set per decoder level) ----
+        self.ra1_conv1 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra1_conv2 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra1_conv3 = Conv(32, 1, 3, 1, padding=1, bn_acti=True)
+
+        self.ra2_conv1 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra2_conv2 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra2_conv3 = Conv(32, 1, 3, 1, padding=1, bn_acti=True)
+
+        self.ra3_conv1 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra3_conv2 = Conv(32, 32, 3, 1, padding=1, bn_acti=True)
+        self.ra3_conv3 = Conv(32, 1, 3, 1, padding=1, bn_acti=True)
+
+        # ---- axial attention kernels ----
+        self.aa_kernel_1 = AA_kernel(32, 32)
+        self.aa_kernel_2 = AA_kernel(32, 32)
+        self.aa_kernel_3 = AA_kernel(32, 32)
+
+    def forward(self, x):
+        
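+        # NOTE: the shape comments below assume a 352x352 input, as used in
+        # the __main__ smoke test at the bottom of this file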
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)      # bs, 64, 88, 88
+        
+        # ----------- low-level features -------------
+        
+        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
+        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44
+        
+        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
+        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11
+        
+        x2_rfb = self.rfb2_1(x2) # 512 - 32
+        x3_rfb = self.rfb3_1(x3) # 1024 - 32
+        x4_rfb = self.rfb4_1(x4) # 2048 - 32
+        
+        decoder_1 = self.agg1(x4_rfb, x3_rfb, x2_rfb) # 1,44,44
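+        # upsample the coarse global map x8 back to the input resolution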
+        lateral_map_1 = F.interpolate(decoder_1, scale_factor=8, mode='bilinear')
+        
+        # ------------------- atten-one -----------------------
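+        # Each attention branch below resamples the previous prediction to the
+        # matching backbone-feature resolution, forms a reverse-attention map
+        # (1 - sigmoid), uses it to gate the axial-attention features, refines
+        # the result with three convs, and adds it back as a residual.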
+        decoder_2 = F.interpolate(decoder_1, scale_factor=0.25, mode='bilinear')  # 44 -> 11, matches x4_rfb
+        cfp_out_1 = self.CFP_3(x4_rfb) # 32 - 32
+        decoder_2_ra = 1 - torch.sigmoid(decoder_2)  # reverse attention
+        aa_atten_3 = self.aa_kernel_3(cfp_out_1)
+        aa_atten_3_o = decoder_2_ra.expand(-1, 32, -1, -1).mul(aa_atten_3)
+        
+        ra_3 = self.ra3_conv1(aa_atten_3_o) # 32 - 32
+        ra_3 = self.ra3_conv2(ra_3) # 32 - 32
+        ra_3 = self.ra3_conv3(ra_3) # 32 - 1
+        
+        x_3 = ra_3 + decoder_2  # residual refinement
+        lateral_map_2 = F.interpolate(x_3, scale_factor=32, mode='bilinear')  # 11 -> 352
+        
+        # ------------------- atten-two -----------------------      
+        decoder_3 = F.interpolate(x_3, scale_factor=2, mode='bilinear')  # 11 -> 22, matches x3_rfb
+        cfp_out_2 = self.CFP_2(x3_rfb) # 32 - 32
+        decoder_3_ra = 1 - torch.sigmoid(decoder_3)  # reverse attention
+        aa_atten_2 = self.aa_kernel_2(cfp_out_2)
+        aa_atten_2_o = decoder_3_ra.expand(-1, 32, -1, -1).mul(aa_atten_2)
+        
+        ra_2 = self.ra2_conv1(aa_atten_2_o) # 32 - 32
+        ra_2 = self.ra2_conv2(ra_2) # 32 - 32
+        ra_2 = self.ra2_conv3(ra_2) # 32 - 1
+        
+        x_2 = ra_2 + decoder_3  # residual refinement
+        lateral_map_3 = F.interpolate(x_2, scale_factor=16, mode='bilinear')  # 22 -> 352
+        
+        # ------------------- atten-three -----------------------
+        decoder_4 = F.interpolate(x_2, scale_factor=2, mode='bilinear')  # 22 -> 44, matches x2_rfb
+        cfp_out_3 = self.CFP_1(x2_rfb) # 32 - 32
+        decoder_4_ra = 1 - torch.sigmoid(decoder_4)  # reverse attention
+        aa_atten_1 = self.aa_kernel_1(cfp_out_3)
+        aa_atten_1_o = decoder_4_ra.expand(-1, 32, -1, -1).mul(aa_atten_1)
+        
+        ra_1 = self.ra1_conv1(aa_atten_1_o) # 32 - 32
+        ra_1 = self.ra1_conv2(ra_1) # 32 - 32
+        ra_1 = self.ra1_conv3(ra_1) # 32 - 1
+        
+        x_1 = ra_1 + decoder_4  # residual refinement
+        lateral_map_5 = F.interpolate(x_1, scale_factor=8, mode='bilinear')  # 44 -> 352
+
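+        # return the four side outputs, from the most refined map (x_1) to the
+        # coarse partial-decoder map, all upsampled to the input resolution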
+        return lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1
+
+
+if __name__ == '__main__':
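+    # quick smoke test: forward one random 352x352 RGB image (requires a CUDA device)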
+    ras = caranet().cuda()
+    input_tensor = torch.randn(1, 3, 352, 352).cuda()
+
+    out = ras(input_tensor)
\ No newline at end of file