CaraNet / lib / self_attention.py

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 10 17:15:44 2021
@author: angelou
"""
import torch
import torch.nn as nn

from lib.conv_layer import Conv


class self_attn(nn.Module):
    """Spatial self-attention whose attention axis is chosen by `mode`:
    'h' attends along height, 'w' along width, and 'hw' over all H*W
    spatial positions."""

    def __init__(self, in_channels, mode='hw'):
        super(self_attn, self).__init__()
        self.mode = mode

        # Query/key projections reduce the channel count by 8x to keep the
        # affinity matrix cheap; the value projection keeps all channels.
        self.query_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1), stride=1, padding=0)
        self.key_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1), stride=1, padding=0)
        self.value_conv = Conv(in_channels, in_channels, kSize=(1, 1), stride=1, padding=0)

        # Learnable residual scale, initialized to zero so the block starts
        # as an identity mapping and learns how much attention to mix in.
        self.gamma = nn.Parameter(torch.zeros(1))
        # Sigmoid gate used in place of the usual softmax normalization.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch_size, channel, height, width = x.size()

        # Length of the attention axis: H, W, or H*W depending on `mode`.
        axis = 1
        if 'h' in self.mode:
            axis *= height
        if 'w' in self.mode:
            axis *= width
        view = (batch_size, -1, axis)

        # Flatten projections to (B, C', axis); queries are transposed to
        # (B, axis, C') so bmm produces a (B, axis, axis) affinity map.
        projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
        projected_key = self.key_conv(x).view(*view)
        attention_map = torch.bmm(projected_query, projected_key)
        attention = self.sigmoid(attention_map)

        # Aggregate the values with the gated attention weights.
        projected_value = self.value_conv(x).view(*view)
        out = torch.bmm(projected_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channel, height, width)

        # Gated residual connection: gamma starts at zero, so out == x at init.
        out = self.gamma * out + x
        return out
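

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test, assuming lib.conv_layer.Conv preserves spatial size
# with kSize=(1, 1) and padding=0 as used above. Because the block is a
# gated residual (gamma starts at zero), the output shape equals the input.
if __name__ == '__main__':
    attn = self_attn(in_channels=32, mode='hw')  # attend over all H*W positions
    x = torch.randn(2, 32, 22, 22)               # (B, C, H, W)
    y = attn(x)
    assert y.shape == x.shape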