--- a
+++ b/model/MyModel.py
@@ -0,0 +1,417 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+from features.feature import *
+from gcn.layers import GConv
+
+
+class Self_Attn(nn.Module):
+    """Self-attention layer."""
+    def __init__(self, in_dim, activation):
+        super(Self_Attn, self).__init__()
+        self.chanel_in = in_dim
+        self.activation = activation
+
+        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x):
+        m_batchsize, C, height, width = x.size()
+        # N = height * width spatial positions
+        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C//8
+        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)                       # B x C//8 x N
+        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
+        attention = self.softmax(energy)                                                        # B x N x N
+        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)                   # B x C x N
+
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(m_batchsize, C, height, width)
+
+        # learnable residual weighting
+        out = self.gamma * out + x
+        return out, attention
+
+
+def init_weights(net, init_type='normal', gain=0.02):
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, gain)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=gain)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:
+            init.normal_(m.weight.data, 1.0, gain)
+            init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)
+
+
+class conv_block(nn.Module):
+    def __init__(self, ch_in, ch_out):
+        super(conv_block, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm2d(ch_out),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm2d(ch_out),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class up_conv(nn.Module):
+    def __init__(self, ch_in, ch_out):
+        super(up_conv, self).__init__()
+        self.up = nn.Sequential(
+            nn.Upsample(scale_factor=2),
+            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm2d(ch_out),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.up(x)
+        return x
+
+
+class Recurrent_block(nn.Module):
+    def __init__(self, ch_out, t=2):
+        super(Recurrent_block, self).__init__()
+        self.t = t
+        self.ch_out = ch_out
+        self.conv = nn.Sequential(
+            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm2d(ch_out),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        # recurrent refinement: the feature map is re-convolved with the input added back, t times
+        for i in range(self.t):
+            if i == 0:
+                x1 = self.conv(x)
+            x1 = self.conv(x + x1)
+        return x1
+
+
+class RRCNN_block(nn.Module):
+    def __init__(self, ch_in, ch_out, t=2):
+        super(RRCNN_block, self).__init__()
+        self.RCNN = nn.Sequential(
+            Recurrent_block(ch_out, t=t),
+            Recurrent_block(ch_out, t=t)
+        )
+        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x = self.Conv_1x1(x)
+        x1 = self.RCNN(x)
+        return x + x1
+
+
+class single_conv(nn.Module):
+    def __init__(self, ch_in, ch_out):
+        super(single_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm2d(ch_out),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class Attention_block(nn.Module):
+    def __init__(self, F_g, F_l, F_int):
+        super(Attention_block, self).__init__()
+        self.W_g = nn.Sequential(
+            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        self.W_x = nn.Sequential(
+            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        self.psi = nn.Sequential(
+            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(1),
+            nn.Sigmoid()
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, g, x):
+        g1 = self.W_g(g)
+        x1 = self.W_x(x)
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)
+
+        return x * psi
+
+
+def dootsu(img):
+    """Apply Otsu thresholding channel by channel; this path is non-differentiable (the input is detached)."""
+    se = img
+    _, c, _, _ = img.size()
+    # squeezelayer = nn.Conv2d(c, c // 16, kernel_size=1)
+    # squeezelayer.cuda()
+    # img = squeezelayer(img)
+    img = img.cpu().detach()
+    channel = list(img.size())[1]
+    batch = list(img.size())[0]
+    imgfolder = img.chunk(batch, dim=0)
+    chw_output = []
+    for index in range(batch):
+        bchw = imgfolder[index]
+        chw = bchw.squeeze(0)
+        chwfolder = chw.chunk(channel, dim=0)
+        hw_output = []
+        for i in range(channel):
+            hw = chwfolder[i].squeeze(0).numpy()
+            # otsu() is expected to come from the star import of features.feature
+            hw_otsu = otsu(hw)
+            hw_otsu = torch.from_numpy(hw_otsu)
+            hw_output.append(hw_otsu)
+        chw_otsu = torch.stack(hw_output, dim=0)
+        chw_output.append(chw_otsu)
+    # move the result back to the device of the original input
+    bchw_otsu = torch.stack(chw_output, dim=0).to(se.device)
+    # result = torch.cat([se.float().cuda(), bchw_otsu.float().cuda()], dim=1)
+    return bchw_otsu
+
+
+class Self_Attn_OTSU(nn.Module):
+    """Self-attention layer followed by channel-wise Otsu thresholding."""
+    def __init__(self, in_dim, activation):
+        super(Self_Attn_OTSU, self).__init__()
+        self.chanel_in = in_dim
+        self.activation = activation
+
+        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x):
+        m_batchsize, C, height, width = x.size()
+        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C//8
+        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)                       # B x C//8 x N
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)                                                        # B x N x N
+        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)                   # B x C x N
+
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(m_batchsize, C, height, width)
+
+        out = dootsu(out)
+
+        out = self.gamma * out + x
+        return out, attention
+
+
+def edge_conv2d(im):
+    """Edge detection with a fixed (non-trainable) convolution kernel."""
+    in_channel = list(im.size())[1]
+    out_channel = in_channel
+    # define the convolution with nn.Conv2d
+    conv_op = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
+    # Edge kernel. Despite the variable name this is a Laplacian-style kernel rather than a true
+    # Sobel operator. Some people divide all values by 3 and feel the result looks better, but
+    # that is probably chance and makes little practical difference.
+    # sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') / 3
+    sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32')
+    # reshape to the conv weight layout (out_channels, in_channels, kH, kW)
+    sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3))
+    # expand along the input-channel dimension
+    sobel_kernel = np.repeat(sobel_kernel, in_channel, axis=1)
+    # expand along the output-channel dimension
+    sobel_kernel = np.repeat(sobel_kernel, out_channel, axis=0)
+
+    # assign the fixed kernel to the conv weights, on the same device as the input
+    conv_op.weight.data = torch.from_numpy(sobel_kernel).to(im.device)
+    # print(conv_op.weight.size())
+    # print(im.size())
+    # print(conv_op, '\n')
+
+    edge_detect = conv_op(im)
+    # print(torch.max(edge_detect))
+    # detach: no gradient flows through the edge branch
+    edge_detect = edge_detect.detach()
+    return edge_detect
+
+
+class Self_Attn_EDGE(nn.Module):
+    """Self-attention layer followed by fixed-kernel edge detection."""
+    def __init__(self, in_dim, activation):
+        super(Self_Attn_EDGE, self).__init__()
+        self.chanel_in = in_dim
+        self.activation = activation
+
+        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x):
+        m_batchsize, C, height, width = x.size()
+        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C//8
+        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)                       # B x C//8 x N
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)                                                        # B x N x N
+        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)                   # B x C x N
+
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(m_batchsize, C, height, width)
+
+        out = edge_conv2d(out)
+
+        out = self.gamma * out + x
+        return out, attention
+
+
+class Self_Attn_GABOR(nn.Module):
+    """Self-attention layer followed by a Gabor convolution (GConv)."""
+    def __init__(self, in_dim, activation):
+        super(Self_Attn_GABOR, self).__init__()
+        self.chanel_in = in_dim
+        self.activation = activation
+
+        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+        self.gaborlayer = GConv(in_dim, out_channels=in_dim // 4, kernel_size=1, stride=1, M=4, nScale=1, bias=False, expand=True)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x):
+        m_batchsize, C, height, width = x.size()
+        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C//8
+        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)                       # B x C//8 x N
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)                                                        # B x N x N
+        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)                   # B x C x N
+
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(m_batchsize, C, height, width)
+        # print(out.size())
+
+        out = self.gaborlayer(out)
+        # print(out.size())
+
+        out = self.gamma * out + x
+        return out, attention
+
+
+class Mymodel(nn.Module):
+    """Attention U-Net with self-attention side branches (Otsu and Gabor) fused into the final prediction."""
+    def __init__(self, img_ch=3, output_ch=1):
+        super(Mymodel, self).__init__()
+
+        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
+        self.Conv2 = conv_block(ch_in=64, ch_out=128)
+        self.Conv3 = conv_block(ch_in=128, ch_out=256)
+        self.Conv4 = conv_block(ch_in=256, ch_out=512)
+        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
+
+        self.Up5 = up_conv(ch_in=1024, ch_out=512)
+        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
+
+        self.Up4 = up_conv(ch_in=512, ch_out=256)
+        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
+        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
+
+        self.Up3 = up_conv(ch_in=256, ch_out=128)
+        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
+        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
+
+        self.Up2 = up_conv(ch_in=128, ch_out=64)
+        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
+        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
+
+        # # gabor
+        # self.gaborlayer = GConv(in_channels=64, out_channels=16, kernel_size=1, stride=1, M=4, nScale=1, bias=False, expand=True)
+
+        self.pred1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
+        self.pred2 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
+        self.pred3 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
+
+        # self-attention branches
+        self.pred_1 = single_conv(ch_in=64, ch_out=32)
+        self.pred_2 = single_conv(ch_in=64, ch_out=32)
+        self.pred_3 = single_conv(ch_in=64, ch_out=32)
+        self.conv_Atten1 = Self_Attn_OTSU(32, 'relu')
+        self.conv_Atten2 = Self_Attn_GABOR(32, 'relu')
+        # self.conv_Atten3 = Self_Attn(32, 'relu')
+        self.conv_Atten3 = Self_Attn_GABOR(32, 'relu')
+
+        # fusion module: three 32-channel attention branches + 64-channel decoder features = 160 channels
+        self.conv_fusion1 = conv_block(ch_in=160, ch_out=64)
+        self.conv_fusion2 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)  # v1
+
+    def forward(self, x):
+        # encoding path
+        x1 = self.Conv1(x)
+
+        x2 = self.Maxpool(x1)
+        x2 = self.Conv2(x2)
+
+        x3 = self.Maxpool(x2)
+        x3 = self.Conv3(x3)
+
+        x4 = self.Maxpool(x3)
+        x4 = self.Conv4(x4)
+
+        x5 = self.Maxpool(x4)
+        x5 = self.Conv5(x5)
+
+        # decoding + concat path
+        d5 = self.Up5(x5)
+        x4 = self.Att5(g=d5, x=x4)
+        d5 = torch.cat((x4, d5), dim=1)
+        d5 = self.Up_conv5(d5)
+
+        d4 = self.Up4(d5)
+        x3 = self.Att4(g=d4, x=x3)
+        d4 = torch.cat((x3, d4), dim=1)
+        d4 = self.Up_conv4(d4)
+
+        d3 = self.Up3(d4)
+        x2 = self.Att3(g=d3, x=x2)
+        d3 = torch.cat((x2, d3), dim=1)
+        d3 = self.Up_conv3(d3)
+
+        d2 = self.Up2(d3)
+        x1 = self.Att2(g=d2, x=x1)
+        d2 = torch.cat((x1, d2), dim=1)
+        d2 = self.Up_conv2(d2)
+
+        # d2_gabor = self.gaborlayer(d2)
+
+        # side predictions
+        pred_1 = self.pred_1(d2)
+        pred_2 = self.pred_2(d2)
+        pred_3 = self.pred_3(d2)
+        pred1 = self.pred1(pred_1)
+        pred2 = self.pred2(pred_2)
+        pred3 = self.pred3(pred_3)
+
+        # self-attention
+        attention_higher, _ = self.conv_Atten1(pred_1)
+        attention_lower, _ = self.conv_Atten2(pred_2)
+        attention_all, _ = self.conv_Atten3(pred_3)
+
+        # fusion module
+        y = torch.cat((attention_higher, attention_all, attention_lower, d2), dim=1)
+        y = self.conv_fusion1(y)
+        pred = self.conv_fusion2(y)
+
+        return [pred1, pred2, pred3, pred]
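+
+
+# Minimal shape-check sketch (editor addition, not part of the original training code).
+# It only illustrates how the module might be exercised: it assumes features.feature provides
+# otsu() and gcn.layers provides a GConv whose output channel count matches its input when
+# expand=True, as the attention branches above rely on. The 64x64 input size is an arbitrary
+# example (any size divisible by 16 fits the four pooling stages); larger inputs make the
+# full-resolution self-attention very memory hungry.
+if __name__ == '__main__':
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = Mymodel(img_ch=3, output_ch=1).to(device)
+    init_weights(model, init_type='normal')
+    dummy = torch.randn(1, 3, 64, 64, device=device)
+    pred1, pred2, pred3, pred = model(dummy)
+    # all four outputs are (B, output_ch, H, W) maps at the input resolution
+    print(pred1.shape, pred2.shape, pred3.shape, pred.shape)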