--- /dev/null
+++ b/CaraNet/lib/self_attention.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Aug 10 17:15:44 2021
+
+@author: angelou
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lib.conv_layer import Conv
+import math
+
+class self_attn(nn.Module):
+    """Self-attention along the height ('h'), width ('w'), or combined
+    spatial ('hw') axis of a feature map."""
+
+    def __init__(self, in_channels, mode='hw'):
+        super(self_attn, self).__init__()
+
+        self.mode = mode
+
+        # Query/key projections reduce channels 8x to keep the attention
+        # matmuls cheap; the value projection keeps the full channel count.
+        self.query_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1), stride=1, padding=0)
+        self.key_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1), stride=1, padding=0)
+        self.value_conv = Conv(in_channels, in_channels, kSize=(1, 1), stride=1, padding=0)
+
+        # gamma starts at zero, so the block is an identity mapping at
+        # initialization and learns how much attention to blend in.
+        self.gamma = nn.Parameter(torch.zeros(1))
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        batch_size, channel, height, width = x.size()
+
+        # Length of the attention axis: height, width, or height * width.
+        axis = 1
+        if 'h' in self.mode:
+            axis *= height
+        if 'w' in self.mode:
+            axis *= width
+
+        view = (batch_size, -1, axis)
+
+        projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
+        projected_key = self.key_conv(x).view(*view)
+
+        # (B, axis, axis) affinity map, gated with a sigmoid rather than
+        # the usual softmax normalization.
+        attention_map = torch.bmm(projected_query, projected_key)
+        attention = self.sigmoid(attention_map)
+        projected_value = self.value_conv(x).view(*view)
+
+        out = torch.bmm(projected_value, attention.permute(0, 2, 1))
+        out = out.view(batch_size, channel, height, width)
+
+        # Residual connection scaled by the learned gamma.
+        out = self.gamma * out + x
+        return out
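
For reference, a minimal smoke test of the module above. This is a sketch, assuming the CaraNet repository layout so that lib.self_attention and lib.conv_layer resolve on the import path; the batch size and 44x44 input are arbitrary illustration values, not taken from the patch:

    import torch
    from lib.self_attention import self_attn

    # Width-axis attention on a 32-channel feature map; the output
    # shape matches the input shape.
    attn = self_attn(in_channels=32, mode='w')
    x = torch.randn(2, 32, 44, 44)
    out = attn(x)
    assert out.shape == x.shape
    # Because gamma is initialized to zero, the attention branch
    # contributes nothing before training and the block is an identity.
    assert torch.allclose(out, x)

The zero-initialized gamma means the module can be inserted into an existing backbone without disturbing its behavior at initialization; the network learns how strongly to weight the attention branch during training.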