import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from features.feature import *  # expected to provide otsu()
from gcn.layers import GConv
class Self_Attn(nn.Module):
""" Self attention Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x):
        m_batchsize, C, width, height = x.size()  # for an NCHW tensor these are (B, C, H, W)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # attention over the N = H*W spatial positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x  # learned residual weight; gamma is initialised to 0
return out,attention
def init_weights(net, init_type='normal', gain=0.02):
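    """Initialise Conv/Linear weights with the chosen scheme and BatchNorm2d weights around 1.

    Typical usage (a sketch): build the network first, then call
    init_weights(net, init_type='kaiming') once before training.
    """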
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class conv_block(nn.Module):
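    """Two stacked Conv3x3 -> BatchNorm2d -> ReLU layers (the standard U-Net double convolution)."""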
def __init__(self,ch_in,ch_out):
super(conv_block,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True),
nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.conv(x)
return x
class up_conv(nn.Module):
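    """2x upsampling (nearest-neighbour by default) followed by Conv3x3 -> BatchNorm2d -> ReLU."""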
def __init__(self,ch_in,ch_out):
super(up_conv,self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.up(x)
return x
class Recurrent_block(nn.Module):
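    """Recurrent convolution block (R2U-Net style): x1 = conv(x), then x1 = conv(x + x1) repeated t times."""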
def __init__(self,ch_out,t=2):
super(Recurrent_block,self).__init__()
self.t = t
self.ch_out = ch_out
self.conv = nn.Sequential(
nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1
class RRCNN_block(nn.Module):
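    """Recurrent residual block: a 1x1 projection to ch_out, two Recurrent_blocks, and a residual connection."""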
def __init__(self,ch_in,ch_out,t=2):
super(RRCNN_block,self).__init__()
self.RCNN = nn.Sequential(
Recurrent_block(ch_out,t=t),
Recurrent_block(ch_out,t=t)
)
self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
def forward(self,x):
x = self.Conv_1x1(x)
x1 = self.RCNN(x)
return x+x1
class single_conv(nn.Module):
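    """Single Conv3x3 -> BatchNorm2d -> ReLU layer."""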
def __init__(self,ch_in,ch_out):
super(single_conv,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.conv(x)
return x
class Attention_block(nn.Module):
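    """Additive attention gate (Attention U-Net): the gating signal g and the skip feature x are
    projected to F_int channels, summed, and turned into a sigmoid mask psi that re-weights x.
    """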
def __init__(self,F_g,F_l,F_int):
super(Attention_block,self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self,g,x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1+x1)
psi = self.psi(psi)
return x*psi
def dootsu(img):
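    """Per-channel Otsu thresholding of a feature map.

    The (B, C, H, W) tensor is moved to the CPU, every (H, W) slice is thresholded with
    otsu() (expected to come from features.feature), and the results are re-stacked and
    moved back to the GPU. No gradient flows through this operation.
    """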
se = img
_, c, _, _ = img.size()
# squeezelayer = nn.Conv2d(c, c // 16, kernel_size=1)
# squeezelayer.cuda()
# img = squeezelayer(img)
img = img.cpu().detach()
channel = list(img.size())[1]
batch = list(img.size())[0]
imgfolder = img.chunk(batch, dim=0)
chw_output = []
for index in range(batch):
bchw = imgfolder[index]
chw = bchw.squeeze()
chwfolder = chw.chunk(channel, dim=0)
hw_output = []
        for i in range(channel):
            hw = chwfolder[i].squeeze()
            hw = hw.numpy()  # CPU tensor, already detached above
            hw_otsu = otsu(hw)  # otsu() is expected to come from features.feature
            hw_otsu = torch.from_numpy(hw_otsu)
hw_output.append(hw_otsu)
chw_otsu = torch.stack(hw_output, dim=0)
chw_output.append(chw_otsu)
bchw_otsu = torch.stack(chw_output, dim=0).cuda()
# result = torch.cat([se.float().cuda(), bchw_otsu.float().cuda()],dim=1)
return bchw_otsu
class Self_Attn_OTSU(nn.Module):
""" Self attention otsu Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn_OTSU,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x):
        m_batchsize, C, width, height = x.size()  # for an NCHW tensor these are (B, C, H, W)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # attention over the N = H*W spatial positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = dootsu(out)  # per-channel Otsu thresholding of the attended features
        out = self.gamma * out + x  # learned residual weight; gamma is initialised to 0
return out,attention
def edge_conv2d(im):
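    """Channel-wise edge response from a fixed 3x3 edge kernel.

    A fresh nn.Conv2d is built on every call and its weights are overwritten with the
    hand-crafted kernel below; the detached output means this branch is not trained.
    """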
in_channel = list(im.size())[1]
out_channel = in_channel
    # Build the convolution that will carry the fixed edge kernel
    conv_op = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
    # Fixed 3x3 edge kernel (a Laplacian-style kernel rather than a true Sobel filter).
    # Dividing every value by 3 is sometimes said to give nicer-looking maps, but since the
    # factor only rescales the response, the scaled variant is left commented out.
    # sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') / 3
    sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32')
    # Reshape to (out_channels, in_channels, kH, kW) ...
    sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3))
    # ... tile over the input channels ...
    sobel_kernel = np.repeat(sobel_kernel, in_channel, axis=1)
    # ... and over the output channels
    sobel_kernel = np.repeat(sobel_kernel, out_channel, axis=0)
    # Load the fixed kernel into the convolution weights (assumes a CUDA device, as elsewhere in this file)
    conv_op.weight.data = torch.from_numpy(sobel_kernel).cuda()
# print(conv_op.weight.size())
# print(im.size())
# print(conv_op, '\n')
edge_detect = conv_op(im)
# print(torch.max(edge_detect))
    # Squeeze singleton dimensions and detach, so the edge map is used as a fixed (non-trainable) feature
    edge_detect = edge_detect.squeeze().detach()
return edge_detect
class Self_Attn_EDGE(nn.Module):
""" Self attention edge Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn_EDGE,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x):
        m_batchsize, C, width, height = x.size()  # for an NCHW tensor these are (B, C, H, W)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # attention over the N = H*W spatial positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = edge_conv2d(out)  # fixed edge-filter response of the attended features (detached)
        out = self.gamma * out + x  # learned residual weight; gamma is initialised to 0
return out,attention
class Self_Attn_GABOR(nn.Module):
""" Self attention gabor Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn_GABOR,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gaborlayer = GConv(in_dim, out_channels=in_dim//4, kernel_size=1, stride=1, M=4, nScale=1, bias=False, expand=True)  # with M=4 orientations and expand=True, GConv is assumed to return in_dim//4 * 4 = in_dim channels, matching the residual addition in forward()
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x):
        m_batchsize, C, width, height = x.size()  # for an NCHW tensor these are (B, C, H, W)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # attention over the N = H*W spatial positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gaborlayer(out)  # Gabor modulation of the attended features
        out = self.gamma * out + x  # learned residual weight; gamma is initialised to 0
return out,attention
class Mymodel(nn.Module):
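    """Attention U-Net backbone with three self-attention heads on the final decoder features.

    The heads (Otsu-, Gabor- and Gabor-based self-attention) each produce an auxiliary
    prediction; a fusion block combines the attended features with the decoder output d2
    to give the final prediction.
    """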
def __init__(self,img_ch=3,output_ch=1):
super(Mymodel,self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
self.Conv2 = conv_block(ch_in=64,ch_out=128)
self.Conv3 = conv_block(ch_in=128,ch_out=256)
self.Conv4 = conv_block(ch_in=256,ch_out=512)
self.Conv5 = conv_block(ch_in=512,ch_out=1024)
self.Up5 = up_conv(ch_in=1024,ch_out=512)
self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
self.Up4 = up_conv(ch_in=512,ch_out=256)
self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
self.Up3 = up_conv(ch_in=256,ch_out=128)
self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
self.Up2 = up_conv(ch_in=128,ch_out=64)
self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
# # gabor
# self.gaborlayer = GConv(in_channels=64, out_channels = 16, kernel_size= 1, stride=1, M=4, nScale=1, bias=False, expand=True)
self.pred1 = nn.Conv2d(32,output_ch,kernel_size=1,stride=1,padding=0)
self.pred2 = nn.Conv2d(32,output_ch,kernel_size=1,stride=1,padding=0)
self.pred3 = nn.Conv2d(32,output_ch,kernel_size=1,stride=1,padding=0)
# self-attention
self.pred_1 = single_conv(ch_in=64, ch_out=32)
self.pred_2 = single_conv(ch_in=64, ch_out=32)
self.pred_3 = single_conv(ch_in=64, ch_out=32)
self.conv_Atten1 = Self_Attn_OTSU(32,'relu')
self.conv_Atten2 = Self_Attn_GABOR(32,'relu')
# self.conv_Atten3 = Self_Attn(32,'relu')
self.conv_Atten3 = Self_Attn_GABOR(32,'relu')
# fusion module
self.conv_fusion1 = conv_block(ch_in=160, ch_out=64)
self.conv_fusion2 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) # v1
def forward(self,x):
# encoding path
x1 = self.Conv1(x)
x2 = self.Maxpool(x1)
x2 = self.Conv2(x2)
x3 = self.Maxpool(x2)
x3 = self.Conv3(x3)
x4 = self.Maxpool(x3)
x4 = self.Conv4(x4)
x5 = self.Maxpool(x4)
x5 = self.Conv5(x5)
# decoding + concat path
d5 = self.Up5(x5)
x4 = self.Att5(g=d5,x=x4)
d5 = torch.cat((x4,d5),dim=1)
d5 = self.Up_conv5(d5)
d4 = self.Up4(d5)
x3 = self.Att4(g=d4,x=x3)
d4 = torch.cat((x3,d4),dim=1)
d4 = self.Up_conv4(d4)
d3 = self.Up3(d4)
x2 = self.Att3(g=d3,x=x2)
d3 = torch.cat((x2,d3),dim=1)
d3 = self.Up_conv3(d3)
d2 = self.Up2(d3)
x1 = self.Att2(g=d2,x=x1)
d2 = torch.cat((x1,d2),dim=1)
d2 = self.Up_conv2(d2)
# d2_gabor = self.gaborlayer(d2)
pred_1 = self.pred_1(d2)
pred_2 = self.pred_2(d2)
pred_3 = self.pred_3(d2)
pred1 = self.pred1(pred_1)
pred2 = self.pred2(pred_2)
pred3 = self.pred3(pred_3)
# self-attention
attention_higher, _ = self.conv_Atten1(pred_1)
attention_lower, _ = self.conv_Atten2(pred_2)
attention_all, _ = self.conv_Atten3(pred_3)
# fusion module
y = torch.cat((attention_higher, attention_all, attention_lower, d2), dim=1)
y = self.conv_fusion1(y)
pred = self.conv_fusion2(y)
return [pred1, pred2, pred3, pred]
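

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training pipeline): the plain
    # Self_Attn block is self-contained and runs on the CPU. Mymodel additionally needs a
    # CUDA device plus the external otsu() (features.feature) and GConv (gcn.layers)
    # dependencies, so it is not exercised here.
    attn = Self_Attn(in_dim=32, activation='relu')
    feat = torch.randn(2, 32, 16, 16)
    out, attention = attn(feat)
    print(out.shape)        # expected: torch.Size([2, 32, 16, 16])
    print(attention.shape)  # expected: torch.Size([2, 256, 256])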