[5d8f6c]: / CaraNet / lib / axial_atten.py

Download this file

30 lines (24 with data), 816 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 10 17:17:13 2021
@author: angelou
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.conv_layer import Conv
from lib.self_attention import self_attn
import math
class AA_kernel(nn.Module):
    """Axial-attention kernel.

    Projects the input with a 1x1 conv, refines it with a 3x3 conv, then
    applies self-attention along the height axis followed by self-attention
    along the width axis. Spatial size is preserved; channel count becomes
    ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super(AA_kernel, self).__init__()
        # NOTE: attribute names are part of the checkpoint state_dict keys —
        # do not rename them.
        self.conv0 = Conv(in_channel, out_channel, kSize=1, stride=1, padding=0)
        self.conv1 = Conv(out_channel, out_channel, kSize=(3, 3), stride=1, padding=1)
        self.Hattn = self_attn(out_channel, mode='h')  # attention over height
        self.Wattn = self_attn(out_channel, mode='w')  # attention over width

    def forward(self, x):
        """Run the axial-attention pipeline on feature map ``x``.

        Args:
            x: input feature tensor; presumably (batch, in_channel, H, W) —
               TODO confirm against Conv/self_attn implementations.

        Returns:
            Attention-refined feature tensor with ``out_channel`` channels.
        """
        features = self.conv1(self.conv0(x))
        attended_h = self.Hattn(features)
        attended_hw = self.Wattn(attended_h)
        return attended_hw