--- a +++ b/ecg_classification/models.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class ConvNormPool(nn.Module): + """Conv Skip-connection module""" + def __init__( + self, + input_size, + hidden_size, + kernel_size, + norm_type='bachnorm' + ): + super().__init__() + + self.kernel_size = kernel_size + self.conv_1 = nn.Conv1d( + in_channels=input_size, + out_channels=hidden_size, + kernel_size=kernel_size + ) + self.conv_2 = nn.Conv1d( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size + ) + self.conv_3 = nn.Conv1d( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size + ) + self.swish_1 = Swish() + self.swish_2 = Swish() + self.swish_3 = Swish() + if norm_type == 'group': + self.normalization_1 = nn.GroupNorm( + num_groups=8, + num_channels=hidden_size + ) + self.normalization_2 = nn.GroupNorm( + num_groups=8, + num_channels=hidden_size + ) + self.normalization_3 = nn.GroupNorm( + num_groups=8, + num_channels=hidden_size + ) + else: + self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size) + self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size) + self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size) + + self.pool = nn.MaxPool1d(kernel_size=2) + + def forward(self, input): + conv1 = self.conv_1(input) + x = self.normalization_1(conv1) + x = self.swish_1(x) + x = F.pad(x, pad=(self.kernel_size - 1, 0)) + + x = self.conv_2(x) + x = self.normalization_2(x) + x = self.swish_2(x) + x = F.pad(x, pad=(self.kernel_size - 1, 0)) + + conv3 = self.conv_3(x) + x = self.normalization_3(conv1+conv3) + x = self.swish_3(x) + x = F.pad(x, pad=(self.kernel_size - 1, 0)) + + x = self.pool(x) + return x + + +class RNN(nn.Module): + """RNN module(cell type lstm or gru)""" + def __init__( + self, + input_size, + hid_size, + num_rnn_layers=1, + dropout_p = 0.2, + bidirectional = False, + rnn_type = 'lstm', + ): + super().__init__() + + if rnn_type == 'lstm': + self.rnn_layer = nn.LSTM( + input_size=input_size, + hidden_size=hid_size, + num_layers=num_rnn_layers, + dropout=dropout_p if num_rnn_layers>1 else 0, + bidirectional=bidirectional, + batch_first=True, + ) + + else: + self.rnn_layer = nn.GRU( + input_size=input_size, + hidden_size=hid_size, + num_layers=num_rnn_layers, + dropout=dropout_p if num_rnn_layers>1 else 0, + bidirectional=bidirectional, + batch_first=True, + ) + def forward(self, input): + outputs, hidden_states = self.rnn_layer(input) + return outputs, hidden_states + + +class CNN(nn.Module): + def __init__( + self, + input_size = 1, + hid_size = 256, + kernel_size = 5, + num_classes = 5, + ): + + super().__init__() + + self.conv1 = ConvNormPool( + input_size=input_size, + hidden_size=hid_size, + kernel_size=kernel_size, + ) + self.conv2 = ConvNormPool( + input_size=hid_size, + hidden_size=hid_size//2, + kernel_size=kernel_size, + ) + self.conv3 = ConvNormPool( + input_size=hid_size//2, + hidden_size=hid_size//4, + kernel_size=kernel_size, + ) + self.avgpool = nn.AdaptiveAvgPool1d((1)) + self.fc = nn.Linear(in_features=hid_size//4, out_features=num_classes) + + def forward(self, input): + x = self.conv1(input) + x = self.conv2(x) + x = self.conv3(x) + x = self.avgpool(x) + # print(x.shape) # num_features * num_channels + x = x.view(-1, x.size(1) * x.size(2)) + x = F.softmax(self.fc(x), dim=1) + return x + + +class RNNModel(nn.Module): + def __init__( + self, + input_size, + hid_size, + rnn_type, + bidirectional, + n_classes=5, + kernel_size=5, + ): + super().__init__() + + self.rnn_layer = RNN( + input_size=46,#hid_size * 2 if bidirectional else hid_size, + hid_size=hid_size, + rnn_type=rnn_type, + bidirectional=bidirectional + ) + self.conv1 = ConvNormPool( + input_size=input_size, + hidden_size=hid_size, + kernel_size=kernel_size, + ) + self.conv2 = ConvNormPool( + input_size=hid_size, + hidden_size=hid_size, + kernel_size=kernel_size, + ) + self.avgpool = nn.AdaptiveAvgPool1d((1)) + self.fc = nn.Linear(in_features=hid_size, out_features=n_classes) + + def forward(self, input): + x = self.conv1(input) + x = self.conv2(x) + x, _ = self.rnn_layer(x) + x = self.avgpool(x) + x = x.view(-1, x.size(1) * x.size(2)) + x = F.softmax(self.fc(x), dim=1)#.squeeze(1) + return x + + +class RNNAttentionModel(nn.Module): + def __init__( + self, + input_size, + hid_size, + rnn_type, + bidirectional, + n_classes=5, + kernel_size=5, + ): + super().__init__() + + self.rnn_layer = RNN( + input_size=46, + hid_size=hid_size, + rnn_type=rnn_type, + bidirectional=bidirectional + ) + self.conv1 = ConvNormPool( + input_size=input_size, + hidden_size=hid_size, + kernel_size=kernel_size, + ) + self.conv2 = ConvNormPool( + input_size=hid_size, + hidden_size=hid_size, + kernel_size=kernel_size, + ) + self.avgpool = nn.AdaptiveMaxPool1d((1)) + self.attn = nn.Linear(hid_size, hid_size, bias=False) + self.fc = nn.Linear(in_features=hid_size, out_features=n_classes) + + def forward(self, input): + x = self.conv1(input) + x = self.conv2(x) + x_out, hid_states = self.rnn_layer(x) + x = torch.cat([hid_states[0], hid_states[1]], dim=0).transpose(0, 1) + x_attn = torch.tanh(self.attn(x)) + x = x_attn.bmm(x_out) + x = x.transpose(2, 1) + x = self.avgpool(x) + x = x.view(-1, x.size(1) * x.size(2)) + x = F.softmax(self.fc(x), dim=-1) + return x + + +if __name__ == '__main__': + rnn_attn = RNNAttentionModel(1, 64, 'lstm', False) + rnn = RNNModel(1, 64, 'lstm', True) + cnn = CNN(num_classes=5, hid_size=128)