import torch
import torch.nn as nn
import torch.nn.functional as F
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x) (equivalent to nn.SiLU)."""
    def forward(self, x):
        return x * torch.sigmoid(x)
class ConvNormPool(nn.Module):
    """Three causal Conv1d blocks with a skip connection, normalization,
    Swish activations, and a final max pooling that halves the sequence length."""
def __init__(
self,
input_size,
hidden_size,
kernel_size,
        norm_type='batchnorm'
):
super().__init__()
self.kernel_size = kernel_size
self.conv_1 = nn.Conv1d(
in_channels=input_size,
out_channels=hidden_size,
kernel_size=kernel_size
)
self.conv_2 = nn.Conv1d(
in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=kernel_size
)
self.conv_3 = nn.Conv1d(
in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=kernel_size
)
self.swish_1 = Swish()
self.swish_2 = Swish()
self.swish_3 = Swish()
if norm_type == 'group':
self.normalization_1 = nn.GroupNorm(
num_groups=8,
num_channels=hidden_size
)
self.normalization_2 = nn.GroupNorm(
num_groups=8,
num_channels=hidden_size
)
self.normalization_3 = nn.GroupNorm(
num_groups=8,
num_channels=hidden_size
)
else:
self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size)
self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size)
self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size)
self.pool = nn.MaxPool1d(kernel_size=2)
    def forward(self, input):
        conv1 = self.conv_1(input)
        x = self.normalization_1(conv1)
        x = self.swish_1(x)
        # Left (causal) padding restores the length lost by the previous convolution.
        x = F.pad(x, pad=(self.kernel_size - 1, 0))
        x = self.conv_2(x)
        x = self.normalization_2(x)
        x = self.swish_2(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))
        conv3 = self.conv_3(x)
        # Skip connection: add the first convolution's output to the third's.
        x = self.normalization_3(conv1 + conv3)
        x = self.swish_3(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))
        # MaxPool1d(2) halves the sequence length.
        x = self.pool(x)
        return x
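
# Illustrative shape check for ConvNormPool (assumed example values, not part of
# the original code): with kernel_size=5 and a 187-sample, single-channel input,
#   >>> block = ConvNormPool(input_size=1, hidden_size=32, kernel_size=5)
#   >>> block(torch.randn(2, 1, 187)).shape
#   torch.Size([2, 32, 93])   # the causal padding keeps 187, MaxPool1d(2) gives 93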
class RNN(nn.Module):
    """Thin wrapper around a recurrent layer (cell type: LSTM or GRU)."""
def __init__(
self,
input_size,
hid_size,
num_rnn_layers=1,
dropout_p = 0.2,
bidirectional = False,
rnn_type = 'lstm',
):
super().__init__()
if rnn_type == 'lstm':
self.rnn_layer = nn.LSTM(
input_size=input_size,
hidden_size=hid_size,
num_layers=num_rnn_layers,
dropout=dropout_p if num_rnn_layers>1 else 0,
bidirectional=bidirectional,
batch_first=True,
)
else:
self.rnn_layer = nn.GRU(
input_size=input_size,
hidden_size=hid_size,
num_layers=num_rnn_layers,
dropout=dropout_p if num_rnn_layers>1 else 0,
bidirectional=bidirectional,
batch_first=True,
)
def forward(self, input):
outputs, hidden_states = self.rnn_layer(input)
return outputs, hidden_states
class CNN(nn.Module):
    """Purely convolutional classifier: three ConvNormPool blocks, global average
    pooling, and a linear head."""
def __init__(
self,
input_size = 1,
hid_size = 256,
kernel_size = 5,
num_classes = 5,
):
super().__init__()
self.conv1 = ConvNormPool(
input_size=input_size,
hidden_size=hid_size,
kernel_size=kernel_size,
)
self.conv2 = ConvNormPool(
input_size=hid_size,
hidden_size=hid_size//2,
kernel_size=kernel_size,
)
self.conv3 = ConvNormPool(
input_size=hid_size//2,
hidden_size=hid_size//4,
kernel_size=kernel_size,
)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(in_features=hid_size//4, out_features=num_classes)
def forward(self, input):
x = self.conv1(input)
x = self.conv2(x)
x = self.conv3(x)
        x = self.avgpool(x)
        # Flatten (batch, channels, 1) -> (batch, channels) before the linear head.
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)  # class probabilities
return x
class RNNModel(nn.Module):
    """Two ConvNormPool blocks followed by a recurrent layer and a linear head."""
def __init__(
self,
input_size,
hid_size,
rnn_type,
bidirectional,
n_classes=5,
kernel_size=5,
):
super().__init__()
self.rnn_layer = RNN(
            # 46 = sequence length after two ConvNormPool blocks (each halves the
            # length), assuming 187-sample inputs; the RNN reads the conv channels
            # as timesteps and the 46 remaining positions as features.
            input_size=46,
hid_size=hid_size,
rnn_type=rnn_type,
bidirectional=bidirectional
)
self.conv1 = ConvNormPool(
input_size=input_size,
hidden_size=hid_size,
kernel_size=kernel_size,
)
self.conv2 = ConvNormPool(
input_size=hid_size,
hidden_size=hid_size,
kernel_size=kernel_size,
)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)
    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)  # (batch, hid_size, reduced_length)
        x, _ = self.rnn_layer(x)
        # Pool over the RNN output features, leaving one value per timestep.
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)
        return x
class RNNAttentionModel(nn.Module):
    """Conv blocks followed by an RNN and a simple attention step over its
    final hidden states."""
def __init__(
self,
input_size,
hid_size,
rnn_type,
bidirectional,
n_classes=5,
kernel_size=5,
):
super().__init__()
self.rnn_layer = RNN(
            # 46 = sequence length after two ConvNormPool blocks, assuming
            # 187-sample inputs (see RNNModel above).
            input_size=46,
hid_size=hid_size,
rnn_type=rnn_type,
bidirectional=bidirectional
)
self.conv1 = ConvNormPool(
input_size=input_size,
hidden_size=hid_size,
kernel_size=kernel_size,
)
self.conv2 = ConvNormPool(
input_size=hid_size,
hidden_size=hid_size,
kernel_size=kernel_size,
)
        self.maxpool = nn.AdaptiveMaxPool1d(1)
self.attn = nn.Linear(hid_size, hid_size, bias=False)
self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)
    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x_out, hid_states = self.rnn_layer(x)
        # For an LSTM, hid_states is the (h_n, c_n) tuple; both are stacked and
        # used as queries. Note this indexing assumes rnn_type='lstm' (a GRU
        # returns a single hidden-state tensor instead of a tuple).
        x = torch.cat([hid_states[0], hid_states[1]], dim=0).transpose(0, 1)
        x_attn = torch.tanh(self.attn(x))
        # Combine the transformed hidden states with the RNN outputs (batched matmul).
        x = x_attn.bmm(x_out)
        x = x.transpose(2, 1)
        x = self.maxpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=-1)
        return x
if __name__ == '__main__':
rnn_attn = RNNAttentionModel(1, 64, 'lstm', False)
rnn = RNNModel(1, 64, 'lstm', True)
cnn = CNN(num_classes=5, hid_size=128)
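
    # Smoke test with illustrative values (not from the original code): the
    # hard-coded input_size=46 in the RNN models implies 187-sample inputs
    # (each ConvNormPool halves the sequence: 187 -> 93 -> 46), i.e. fixed-length
    # single-channel signals shaped (batch, channels=1, length=187).
    dummy = torch.randn(8, 1, 187)
    print(rnn_attn(dummy).shape)  # expected: torch.Size([8, 5])
    print(rnn(dummy).shape)       # expected: torch.Size([8, 5])
    print(cnn(dummy).shape)       # expected: torch.Size([8, 5])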