attention.py

import torch
import torch.nn as nn


class PAM_Module(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        # 1x1 convolutions project the input into query/key/value spaces;
        # query and key are reduced to in_dim // 8 channels as in SAGAN.
        self.query_conv = nn.Conv1d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual weight, initialized to zero so the module
        # starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X L)
        returns :
            out : attention value + input feature (B X C X L)
        """
        m_batchsize, C, length = x.size()
        proj_query = self.query_conv(x).permute(0, 2, 1)          # B X L X C//8
        proj_key = self.key_conv(x)                               # B X C//8 X L
        energy = torch.bmm(proj_query, proj_key)                  # B X L X L
        attention = self.softmax(energy)
        proj_value = self.value_conv(x)                           # B X C X L
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))   # B X C X L
        out = self.gamma * out + x
        return out
class CAM_Module(nn.Module):
    """ Channel attention module"""
    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X L)
        returns :
            out : attention value + input feature (B X C X L)
            (the channel attention map computed internally is B X C X C)
        """
        m_batchsize, C, length = x.size()
        proj_query = x                             # B X C X L
        proj_key = x.permute(0, 2, 1)              # B X L X C
        energy = torch.bmm(proj_query, proj_key)   # B X C X C
        # As in the DANet reference code, the softmax is applied to
        # (row-wise max - energy) rather than to the raw energies.
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x
        out = torch.bmm(attention, proj_value)     # B X C X L
        out = self.gamma * out + x
        return out
class SAModule(nn.Module):
    """
    Re-implementation of the spatial attention module (SAM) described in:
    *Fu et al., Dual Attention Network for Scene Segmentation, CVPR 2019*
    code reference:
    https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py
    """
    def __init__(self, num_channels):
        super(SAModule, self).__init__()
        self.num_channels = num_channels
        self.cam = CAM_Module(num_channels)
        self.pam = PAM_Module(num_channels)

    def forward(self, feat_map):
        # Average the channel-attention and position-attention branches.
        feat_map_cam = self.cam(feat_map)
        feat_map_pam = self.pam(feat_map)
        feat_map = (feat_map_cam + feat_map_pam) / 2
        return feat_map
class CAModule(nn.Module):
    """
    Squeeze-and-excite channel attention module (CAM).
    Re-implementation of the Squeeze-and-Excitation (SE) block described in:
    *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
    code reference:
    https://github.com/kobiso/CBAM-keras/blob/master/models/attention_module.py
    """
    def __init__(self, num_channels, reduc_ratio=2):
        super(CAModule, self).__init__()
        self.num_channels = num_channels
        self.reduc_ratio = reduc_ratio
        self.fc1 = nn.Linear(num_channels, num_channels // reduc_ratio,
                             bias=True)
        self.fc2 = nn.Linear(num_channels // reduc_ratio, num_channels,
                             bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, feat_map):
        # Attention branch -- squeeze operation: global average pool over
        # the spatial dimension(s), yielding one descriptor per channel.
        gap_out = feat_map.view(feat_map.size()[0], self.num_channels,
                                -1).mean(dim=2)
        # Attention branch -- excitation operation: bottleneck MLP with a
        # sigmoid gate.
        fc1_out = self.relu(self.fc1(gap_out))
        fc2_out = self.sigmoid(self.fc2(fc1_out))
        # Attention operation: broadcast the per-channel gate over the input.
        fc2_out = fc2_out.view(fc2_out.size()[0], fc2_out.size()[1], 1)
        feat_map = torch.mul(feat_map, fc2_out)
        return feat_map
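

if __name__ == "__main__":
    # Minimal smoke test, not part of the referenced DANet/SE code: the batch
    # size, channel count, and sequence length below are illustrative
    # assumptions, chosen only to show that every module preserves the
    # (B, C, L) shape of its input.
    x = torch.randn(2, 64, 32)
    print(PAM_Module(64)(x).shape)   # torch.Size([2, 64, 32])
    print(CAM_Module(64)(x).shape)   # torch.Size([2, 64, 32])
    print(SAModule(64)(x).shape)     # torch.Size([2, 64, 32])
    print(CAModule(64)(x).shape)     # torch.Size([2, 64, 32])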