BioSeqNet/resnest/gluon/splat.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.nn import Conv1D, Block, HybridBlock, Dense, BatchNorm, Activation

__all__ = ['SplitAttentionConv']

USE_BN = True

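# Split-Attention convolution (the SplAtConv block from ResNeSt), here operating
# on 1D sequence inputs in NCW layout. The input is convolved into `radix`
# feature-map groups per cardinal group; a channel-wise attention vector,
# computed from the pooled sum of the groups, reweights each group before the
# groups are summed back together.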
class SplitAttentionConv(HybridBlock):
    def __init__(self, channels, kernel_size, strides=1, padding=0,
                 dilation=1, groups=1, radix=2, *args, in_channels=None, r=2,
                 norm_layer=BatchNorm, norm_kwargs=None, drop_ratio=0, **kwargs):
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        inter_channels = max(in_channels*radix//2//r, 32)
        self.radix = radix
        self.cardinality = groups
        self.conv = Conv1D(channels*radix, kernel_size, strides, padding, dilation,
                           groups=groups*radix, *args, in_channels=in_channels, **kwargs)
        if USE_BN:
            self.bn = norm_layer(in_channels=channels*radix, **norm_kwargs)
        self.relu = Activation('relu')
        self.fc1 = Conv1D(inter_channels, 1, in_channels=channels, groups=self.cardinality)
        if USE_BN:
            self.bn1 = norm_layer(in_channels=inter_channels, **norm_kwargs)
        self.relu1 = Activation('relu')
        if drop_ratio > 0:
            self.drop = nn.Dropout(drop_ratio)
        else:
            self.drop = None
        self.fc2 = Conv1D(channels*radix, 1, in_channels=inter_channels, groups=self.cardinality)
        self.channels = channels
        self.rsoftmax = rSoftMax(radix, groups)

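    # Forward pass: grouped conv -> BN/ReLU -> split into `radix` groups -> sum and
    # global average pool -> two grouped 1x1 convs -> radix-wise softmax weights ->
    # weighted sum of the radix groups.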
    def hybrid_forward(self, F, x):
        x = self.conv(x)
        if USE_BN:
            x = self.bn(x)
        x = self.relu(x)
        if self.radix > 1:
            splited = F.split(x, self.radix, axis=1)
            gap = sum(splited)
        else:
            gap = x
        # Global average pool over the length axis (output size 1). This replaces
        # contrib.AdaptiveAvgPooling2D from the 2D version of this block; MXNet's
        # contrib has no 1D adaptive-pooling operator, so F.mean is used instead.
        gap = F.mean(gap, axis=2, keepdims=True)
        gap = self.fc1(gap)
        if USE_BN:
            gap = self.bn1(gap)
        atten = self.relu1(gap)
        if self.drop:
            atten = self.drop(atten)
        # One attention logit per channel and radix group: reshaped to
        # (N, radix, channels) for the radix-wise softmax/sigmoid, then to
        # (N, -1, 1) for the channel weighting below.
        atten = self.fc2(atten).reshape((0, self.radix, self.channels))
        atten = self.rsoftmax(atten).reshape((0, -1, 1))
        if self.radix > 1:
            atten = F.split(atten, self.radix, axis=1)
            outs = [F.broadcast_mul(att, split) for (att, split) in zip(atten, splited)]
            out = sum(outs)
        else:
            out = F.broadcast_mul(atten, x)
        return out


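# Radix-wise softmax: with radix > 1 the attention logits are normalized across
# the radix dimension within each cardinal group; with radix == 1 a plain
# sigmoid gate is used instead.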
class rSoftMax(nn.HybridBlock):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def hybrid_forward(self, F, x):
        if self.radix > 1:
            x = x.reshape((0, self.cardinality, self.radix, -1)).swapaxes(1, 2)
            x = F.softmax(x, axis=1)
            x = x.reshape((0, -1))
        else:
            x = F.sigmoid(x)
        return x
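

# Minimal smoke-test sketch (not part of the original file): one plausible way to
# instantiate the block on dummy NCW data. All shapes and hyperparameters below
# are illustrative assumptions, not values used by BioSeqNet.
if __name__ == '__main__':
    layer = SplitAttentionConv(64, kernel_size=3, padding=1, groups=2, radix=2,
                               in_channels=32)
    layer.initialize()
    x = mx.nd.random.uniform(shape=(8, 32, 100))  # (batch, channels, length)
    y = layer(x)
    print(y.shape)  # expected: (8, 64, 100)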