b/BioSeqNet/resnest/gluon/splat.py
|
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.nn import Conv1D, Block, HybridBlock, Dense, BatchNorm, Activation

__all__ = ['SplitAttentionConv']

USE_BN = True


class SplitAttentionConv(HybridBlock):
    """Split-Attention Conv1D: ResNeSt-style split-attention convolution adapted to
    1D sequence inputs in NCW layout."""
    def __init__(self, channels, kernel_size, strides=1, padding=0,
                 dilation=1, groups=1, radix=2, *args, in_channels=None, r=2,
                 norm_layer=BatchNorm, norm_kwargs=None, drop_ratio=0, **kwargs):
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        # Bottleneck width of the attention branch, floored at 32 channels.
        inter_channels = max(in_channels*radix//2//r, 32)
        self.radix = radix
        self.cardinality = groups
        # Grouped convolution that produces radix splits per cardinal group.
        self.conv = Conv1D(channels*radix, kernel_size, strides, padding, dilation,
                           groups=groups*radix, *args, in_channels=in_channels, **kwargs)
        if USE_BN:
            self.bn = norm_layer(in_channels=channels*radix, **norm_kwargs)
        self.relu = Activation('relu')
        # 1x1 convolutions mapping the pooled descriptor to per-split channel attention logits.
        self.fc1 = Conv1D(inter_channels, 1, in_channels=channels, groups=self.cardinality)
        if USE_BN:
            self.bn1 = norm_layer(in_channels=inter_channels, **norm_kwargs)
        self.relu1 = Activation('relu')
        if drop_ratio > 0:
            self.drop = nn.Dropout(drop_ratio)
        else:
            self.drop = None
        self.fc2 = Conv1D(channels*radix, 1, in_channels=inter_channels, groups=self.cardinality)
        self.channels = channels
        self.rsoftmax = rSoftMax(radix, groups)

    def hybrid_forward(self, F, x):
        x = self.conv(x)
        if USE_BN:
            x = self.bn(x)
        x = self.relu(x)
        if self.radix > 1:
            # Split the channel axis into radix branches and fuse them by summation.
            splited = F.split(x, self.radix, axis=1)
            gap = sum(splited)
        else:
            gap = x
        # Global average pooling over the length axis. Stock MXNet has no
        # F.contrib.AdaptiveAvgPooling1D operator; with an output size of 1 the mean
        # over axis 2 is equivalent, so the original call is kept only as a reference.
        # gap = F.contrib.AdaptiveAvgPooling1D(gap, 1)
        gap = F.mean(gap, axis=2, keepdims=True)
        gap = self.fc1(gap)
        if USE_BN:
            gap = self.bn1(gap)
        atten = self.relu1(gap)
        if self.drop:
            atten = self.drop(atten)
        # Per-split channel attention, normalized across the radix axis by rSoftMax.
        atten = self.fc2(atten).reshape((0, self.radix, self.channels))
        #atten = self.rsoftmax(atten).reshape((0, -1, 1, 1))
        atten = self.rsoftmax(atten).reshape((0, -1, 1))
        if self.radix > 1:
            # Weight each split by its attention vector and sum the splits.
            atten = F.split(atten, self.radix, axis=1)
            outs = [F.broadcast_mul(att, split) for (att, split) in zip(atten, splited)]
            out = sum(outs)
        else:
            out = F.broadcast_mul(atten, x)
        return out


class rSoftMax(nn.HybridBlock):
    """Radix-wise softmax: normalizes attention logits across the radix splits within
    each cardinal group (falls back to sigmoid when radix == 1)."""
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def hybrid_forward(self, F, x):
        if self.radix > 1:
            # Regroup to (N, radix, cardinality, -1) so softmax over axis 1 compares the
            # radix splits of each channel, then flatten back to (N, radix*channels).
            x = x.reshape((0, self.cardinality, self.radix, -1)).swapaxes(1, 2)
            x = F.softmax(x, axis=1)
            x = x.reshape((0, -1))
        else:
            x = F.sigmoid(x)
        return x
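

# A minimal usage sketch, not from the original source: it assumes NCW-layout sequence
# features of shape (batch, channels, length) and the default radix=2 / groups=1 setup;
# the channel count, kernel size, and input length below are illustrative values only.
if __name__ == '__main__':
    block = SplitAttentionConv(64, kernel_size=3, padding=1, in_channels=64)
    block.initialize()
    block.hybridize()
    x = mx.nd.random.uniform(shape=(2, 64, 100))  # (batch=2, channels=64, length=100)
    y = block(x)
    print(y.shape)  # (2, 64, 100): `channels` output maps, length preserved by padding=1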
|
|