--- /dev/null
+++ b/BioSeqNet/resnest/torch/splat.py
@@ -0,0 +1,99 @@
+"""Split-Attention"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, Module, ReLU
+
+__all__ = ['SplAtConv1d']
+
+class SplAtConv1d(Module):
+    """Split-Attention Conv1d
+    """
+    def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0,
+                 dilation=1, groups=1, bias=True,
+                 radix=2, reduction_factor=4,
+                 rectify=False, rectify_avg=False, norm_layer=None,
+                 dropblock_prob=0.0, **kwargs):
+        super(SplAtConv1d, self).__init__()
+        # padding is a plain int in the 1-D case, so no _pair(...) as in the 2-D version
+        self.rectify = rectify and (padding > 0)
+        self.rectify_avg = rectify_avg
+        inter_channels = max(in_channels*radix//reduction_factor, 32)
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        self.dropblock_prob = dropblock_prob
+        if self.rectify:
+            from rfconv import RFConv1d
+            self.conv = RFConv1d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
+                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
+        else:
+            self.conv = Conv1d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
+                               groups=groups*radix, bias=bias, **kwargs)
+        self.use_bn = norm_layer is not None
+        if self.use_bn:
+            self.bn0 = norm_layer(channels*radix)
+        self.relu = ReLU(inplace=True)
+        self.fc1 = Conv1d(channels, inter_channels, 1, groups=self.cardinality)
+        if self.use_bn:
+            self.bn1 = norm_layer(inter_channels)
+        self.fc2 = Conv1d(inter_channels, channels*radix, 1, groups=self.cardinality)
+        if dropblock_prob > 0.0:
+            # DropBlock2D from the 2-D implementation is undefined in this module;
+            # fail fast here rather than with a NameError.
+            raise NotImplementedError('dropblock_prob > 0 is not supported by SplAtConv1d')
+        self.rsoftmax = rSoftMax(radix, groups)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.use_bn:
+            x = self.bn0(x)
+        x = self.relu(x)
+
+        batch, rchannel = x.shape[:2]
+        if self.radix > 1:
+            # split the radix groups along the channel axis and pool their sum
+            splits = torch.split(x, rchannel//self.radix, dim=1)
+            gap = sum(splits)
+        else:
+            gap = x
+        # global average pooling over the sequence dimension
+        gap = F.adaptive_avg_pool1d(gap, 1)
+        gap = self.fc1(gap)
+
+        if self.use_bn:
+            gap = self.bn1(gap)
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1)
+
+        if self.radix > 1:
+            # weight each radix split by its attention slice and sum them
+            attens = torch.split(atten, rchannel//self.radix, dim=1)
+            out = sum([att*split for (att, split) in zip(attens, splits)])
+        else:
+            out = atten * x
+        return out.contiguous()
+
+class rSoftMax(nn.Module):
+    """Softmax over the radix dimension; reduces to a sigmoid when radix == 1."""
+    def __init__(self, radix, cardinality):
+        super().__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            # regroup as (batch, cardinality, radix, -1) and softmax across radix
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
+
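For reference, a minimal smoke test of the layer above. This is a sketch, assuming the repository root is on PYTHONPATH so the module imports as resnest.torch.splat; the shapes are illustrative only:

    import torch
    from torch.nn import BatchNorm1d
    from resnest.torch.splat import SplAtConv1d

    # 8 sequences, 64 input channels, length 128; kernel 3 with padding 1
    # keeps the sequence length unchanged.
    layer = SplAtConv1d(64, 128, kernel_size=3, padding=1,
                        radix=2, groups=1, norm_layer=BatchNorm1d)
    out = layer(torch.randn(8, 64, 128))
    print(out.shape)  # torch.Size([8, 128, 128])

With radix=2 the internal convolution produces 2*128 channels; the two radix splits are summed and globally pooled, and rSoftMax then distributes attention weights across the splits before they are recombined into the 128 output channels.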