--- a +++ b/CaraNet/CaraNet.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 21 14:58:14 2021 + +@author: angelou +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pretrain.Res2Net_v1b import res2net50_v1b_26w_4s, res2net101_v1b_26w_4s +import math +import torchvision.models as models +from lib.conv_layer import Conv, BNPReLU +from lib.axial_atten import AA_kernel +from lib.context_module import CFPModule +from lib.partial_decoder import aggregation +import os + + + + +class caranet(nn.Module): + def __init__(self, channel=32): + super().__init__() + + # ---- ResNet Backbone ---- + self.resnet = res2net101_v1b_26w_4s(pretrained=True) + + # Receptive Field Block + self.rfb2_1 = Conv(512, 32,3,1,padding=1,bn_acti=True) + self.rfb3_1 = Conv(1024, 32,3,1,padding=1,bn_acti=True) + self.rfb4_1 = Conv(2048, 32,3,1,padding=1,bn_acti=True) + + # Partial Decoder + self.agg1 = aggregation(channel) + + self.CFP_1 = CFPModule(32, d = 8) + self.CFP_2 = CFPModule(32, d = 8) + self.CFP_3 = CFPModule(32, d = 8) + ###### dilation rate 4, 62.8 + + + + self.ra1_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra1_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra1_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True) + + self.ra2_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra2_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra2_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True) + + self.ra3_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra3_conv2 = Conv(32,32,3,1,padding=1,bn_acti=True) + self.ra3_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True) + + self.aa_kernel_1 = AA_kernel(32,32) + self.aa_kernel_2 = AA_kernel(32,32) + self.aa_kernel_3 = AA_kernel(32,32) + + + + + def forward(self, x): + + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x = self.resnet.relu(x) + x = self.resnet.maxpool(x) # bs, 64, 88, 88 + + # ----------- low-level features ------------- + + x1 = self.resnet.layer1(x) # bs, 256, 88, 88 + x2 = self.resnet.layer2(x1) # bs, 512, 44, 44 + + x3 = self.resnet.layer3(x2) # bs, 1024, 22, 22 + x4 = self.resnet.layer4(x3) # bs, 2048, 11, 11 + + x2_rfb = self.rfb2_1(x2) # 512 - 32 + x3_rfb = self.rfb3_1(x3) # 1024 - 32 + x4_rfb = self.rfb4_1(x4) # 2048 - 32 + + decoder_1 = self.agg1(x4_rfb, x3_rfb, x2_rfb) # 1,44,44 + lateral_map_1 = F.interpolate(decoder_1, scale_factor=8, mode='bilinear') + + # ------------------- atten-one ----------------------- + decoder_2 = F.interpolate(decoder_1, scale_factor=0.25, mode='bilinear') + cfp_out_1 = self.CFP_3(x4_rfb) # 32 - 32 + decoder_2_ra = -1*(torch.sigmoid(decoder_2)) + 1 + aa_atten_3 = self.aa_kernel_3(cfp_out_1) + aa_atten_3_o = decoder_2_ra.expand(-1, 32, -1, -1).mul(aa_atten_3) + + ra_3 = self.ra3_conv1(aa_atten_3_o) # 32 - 32 + ra_3 = self.ra3_conv2(ra_3) # 32 - 32 + ra_3 = self.ra3_conv3(ra_3) # 32 - 1 + + x_3 = ra_3 + decoder_2 + lateral_map_2 = F.interpolate(x_3,scale_factor=32,mode='bilinear') + + # ------------------- atten-two ----------------------- + decoder_3 = F.interpolate(x_3, scale_factor=2, mode='bilinear') + cfp_out_2 = self.CFP_2(x3_rfb) # 32 - 32 + decoder_3_ra = -1*(torch.sigmoid(decoder_3)) + 1 + aa_atten_2 = self.aa_kernel_2(cfp_out_2) + aa_atten_2_o = decoder_3_ra.expand(-1, 32, -1, -1).mul(aa_atten_2) + + ra_2 = self.ra2_conv1(aa_atten_2_o) # 32 - 32 + ra_2 = self.ra2_conv2(ra_2) # 32 - 32 + ra_2 = self.ra2_conv3(ra_2) # 32 - 1 + + x_2 = ra_2 + decoder_3 + lateral_map_3 = F.interpolate(x_2,scale_factor=16,mode='bilinear') + + # ------------------- atten-three ----------------------- + decoder_4 = F.interpolate(x_2, scale_factor=2, mode='bilinear') + cfp_out_3 = self.CFP_1(x2_rfb) # 32 - 32 + decoder_4_ra = -1*(torch.sigmoid(decoder_4)) + 1 + aa_atten_1 = self.aa_kernel_1(cfp_out_3) + aa_atten_1_o = decoder_4_ra.expand(-1, 32, -1, -1).mul(aa_atten_1) + + ra_1 = self.ra1_conv1(aa_atten_1_o) # 32 - 32 + ra_1 = self.ra1_conv2(ra_1) # 32 - 32 + ra_1 = self.ra1_conv3(ra_1) # 32 - 1 + + x_1 = ra_1 + decoder_4 + lateral_map_5 = F.interpolate(x_1,scale_factor=8,mode='bilinear') + + + + return lateral_map_5,lateral_map_3,lateral_map_2,lateral_map_1 + + + +if __name__ == '__main__': + ras = caranet().cuda() + input_tensor = torch.randn(1, 3, 352, 352).cuda() + + out = ras(input_tensor) \ No newline at end of file