# ecg_classification/models.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
6
class Swish(nn.Module):
    """Swish activation, applied element-wise: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
9
10
11
class ConvNormPool(nn.Module):
    """Three-convolution block with a skip connection, followed by max pooling.

    Each of the three ``Conv1d`` layers (no padding) is followed by a
    normalization layer and a Swish activation. The raw output of the first
    convolution is added to the output of the third convolution before the
    last normalization. After every activation the last axis is left-padded
    with ``kernel_size - 1`` zeros, restoring the length lost by the unpadded
    convolution. A final ``MaxPool1d(2)`` halves that length (floor division).

    Args:
        input_size: number of input channels.
        hidden_size: number of output channels of every convolution.
        kernel_size: convolution kernel width.
        norm_type: ``'group'`` selects ``GroupNorm(8, hidden_size)`` (so
            ``hidden_size`` must be divisible by 8); any other value,
            including the default ``'bachnorm'`` (sic), selects
            ``BatchNorm1d``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        kernel_size,
        norm_type='bachnorm'
    ):
        super().__init__()

        self.kernel_size = kernel_size

        def make_conv(in_channels):
            return nn.Conv1d(
                in_channels=in_channels,
                out_channels=hidden_size,
                kernel_size=kernel_size,
            )

        def make_norm():
            if norm_type == 'group':
                return nn.GroupNorm(num_groups=8, num_channels=hidden_size)
            return nn.BatchNorm1d(num_features=hidden_size)

        # Creation order (conv_1, conv_2, conv_3, then the norms) is kept so
        # random weight initialisation consumes the RNG as before.
        self.conv_1 = make_conv(input_size)
        self.conv_2 = make_conv(hidden_size)
        self.conv_3 = make_conv(hidden_size)
        self.swish_1, self.swish_2, self.swish_3 = Swish(), Swish(), Swish()
        self.normalization_1 = make_norm()
        self.normalization_2 = make_norm()
        self.normalization_3 = make_norm()

        self.pool = nn.MaxPool1d(kernel_size=2)

    def _pad_left(self, x):
        """Left-pad the last axis with ``kernel_size - 1`` zeros."""
        return F.pad(x, pad=(self.kernel_size - 1, 0))

    def forward(self, input):
        skip = self.conv_1(input)
        x = self._pad_left(self.swish_1(self.normalization_1(skip)))
        x = self._pad_left(self.swish_2(self.normalization_2(self.conv_2(x))))
        # Residual: the un-normalized first conv output joins the third conv.
        x = self.normalization_3(skip + self.conv_3(x))
        x = self._pad_left(self.swish_3(x))
        return self.pool(x)
79
      
80
      
81
class RNN(nn.Module):
    """Thin wrapper around a batch-first ``nn.LSTM`` or ``nn.GRU``.

    Args:
        input_size: number of features per time step.
        hid_size: hidden-state size of each direction.
        num_rnn_layers: number of stacked recurrent layers.
        dropout_p: dropout between stacked layers; only forwarded when
            ``num_rnn_layers > 1``, otherwise 0 is used.
        bidirectional: whether the recurrent layer is bidirectional.
        rnn_type: ``'lstm'`` builds an ``nn.LSTM``; any other value builds
            an ``nn.GRU``.
    """

    def __init__(
        self,
        input_size,
        hid_size,
        num_rnn_layers=1,
        dropout_p=0.2,
        bidirectional=False,
        rnn_type='lstm',
    ):
        super().__init__()

        rnn_cls = nn.LSTM if rnn_type == 'lstm' else nn.GRU
        # PyTorch applies this dropout only between stacked layers (and warns
        # when it is non-zero for a single layer), so drop it in that case.
        inter_layer_dropout = dropout_p if num_rnn_layers > 1 else 0
        self.rnn_layer = rnn_cls(
            input_size=input_size,
            hidden_size=hid_size,
            num_layers=num_rnn_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, input):
        """Run the wrapped layer and return ``(outputs, hidden_states)``.

        For an LSTM ``hidden_states`` is the pair ``(h_n, c_n)``; for a GRU
        it is the single tensor ``h_n``.
        """
        return self.rnn_layer(input)
116
      
117
      
118
class CNN(nn.Module):
    """Convolutional classifier: three ``ConvNormPool`` stages, global
    average pooling, and a linear layer followed by softmax.

    Channel widths go ``input_size -> hid_size -> hid_size // 2 ->
    hid_size // 4``; each stage halves the length of the last axis.

    Args:
        input_size: number of input channels.
        hid_size: channel width of the first stage.
        kernel_size: kernel width used by every stage.
        num_classes: number of output classes.

    ``forward`` returns softmax probabilities (not logits) over ``dim=1``.
    NOTE(review): if training uses ``nn.CrossEntropyLoss`` (which expects
    logits), softmax would be applied twice — confirm against the caller.
    """

    def __init__(
        self,
        input_size=1,
        hid_size=256,
        kernel_size=5,
        num_classes=5,
    ):
        super().__init__()

        widths = (input_size, hid_size, hid_size // 2, hid_size // 4)
        self.conv1 = ConvNormPool(
            input_size=widths[0],
            hidden_size=widths[1],
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=widths[1],
            hidden_size=widths[2],
            kernel_size=kernel_size,
        )
        self.conv3 = ConvNormPool(
            input_size=widths[2],
            hidden_size=widths[3],
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(in_features=widths[3], out_features=num_classes)

    def forward(self, input):
        # Presumably input is (batch, input_size, length) — TODO confirm.
        x = input
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        x = self.avgpool(x)
        # Collapse the pooled (batch, channels, 1) tensor to (batch, channels).
        x = x.view(-1, x.size(1) * x.size(2))
        return F.softmax(self.fc(x), dim=1)
156
      
157
      
158
class RNNModel(nn.Module):
    """Conv front-end followed by a recurrent layer and a linear classifier.

    Two ``ConvNormPool`` stages each halve the last axis (floor division).
    Their output is fed to the batch-first RNN unchanged, so the RNN treats
    the ``hid_size`` conv channels as time steps and the pooled signal
    length as its per-step feature size. The RNN outputs are averaged over
    their feature axis, flattened and classified.

    Args:
        input_size: number of input channels.
        hid_size: conv channel width and RNN hidden size.
        rnn_type: ``'lstm'`` builds an LSTM, any other value a GRU.
        bidirectional: whether the RNN is bidirectional.
        n_classes: number of output classes.
        kernel_size: kernel width of the conv stages.
        rnn_input_size: feature size expected by the RNN, i.e. the input
            length after both pooling stages (``length // 2 // 2``). The
            default 46 is the previously hard-coded value (e.g. length
            187 -> 93 -> 46).

    ``forward`` returns softmax probabilities over ``dim=1``.
    """

    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
        rnn_input_size=46,
    ):
        super().__init__()

        # Module creation order is unchanged so RNG-based initialisation
        # matches earlier versions of this class.
        self.rnn_layer = RNN(
            input_size=rnn_input_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # The avgpool collapses the RNN feature axis (hid_size * directions),
        # leaving one value per step; there are hid_size steps.
        self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)

    def forward(self, input):
        # Presumably input is (batch, input_size, length) — TODO confirm.
        x = self.conv1(input)
        x = self.conv2(x)
        # batch_first: steps = hid_size channels, features = pooled length.
        x, _ = self.rnn_layer(x)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        return F.softmax(self.fc(x), dim=1)
197
      
198
      
199
class RNNAttentionModel(nn.Module):
    """Conv front-end, recurrent layer, and a final-state attention head.

    Two ``ConvNormPool`` stages produce ``hid_size`` channels, which the
    batch-first RNN consumes as ``hid_size`` time steps. The RNN's final
    hidden state(s) are projected by ``attn`` and ``tanh`` into weights over
    those steps. These weights are multiplied with the RNN outputs,
    max-pooled over the query axis and classified.

    Args:
        input_size: number of input channels.
        hid_size: conv channel width and RNN hidden size.
        rnn_type: ``'lstm'`` builds an LSTM, any other value a GRU.
        bidirectional: whether the RNN is bidirectional.
        n_classes: number of output classes.
        kernel_size: kernel width of the conv stages.
        rnn_input_size: feature size expected by the RNN, i.e. the input
            length after both pooling stages (``length // 2 // 2``). The
            default 46 is the previously hard-coded value.

    ``forward`` returns softmax probabilities over the last dimension.
    """

    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
        rnn_input_size=46,
    ):
        super().__init__()

        # Module creation order is unchanged so RNG-based initialisation
        # matches earlier versions of this class.
        self.rnn_layer = RNN(
            input_size=rnn_input_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        # Despite the attribute name this is a *max* pool; the name is kept
        # so existing references to ``model.avgpool`` keep working.
        self.avgpool = nn.AdaptiveMaxPool1d(1)
        self.attn = nn.Linear(hid_size, hid_size, bias=False)
        # The attended features have the RNN output width, which doubles for
        # a bidirectional RNN (previously this crashed in forward).
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(
            in_features=hid_size * num_directions,
            out_features=n_classes,
        )

    def forward(self, input):
        # Presumably input is (batch, input_size, length) — TODO confirm.
        x = self.conv1(input)
        x = self.conv2(x)
        # x_out: (batch, hid_size steps, hid_size * num_directions)
        x_out, hid_states = self.rnn_layer(x)
        if isinstance(hid_states, tuple):
            # LSTM returns (h_n, c_n); both are used as query vectors.
            query = torch.cat([hid_states[0], hid_states[1]], dim=0)
        else:
            # GRU returns h_n alone; indexing it like the LSTM pair would
            # pick layers/directions instead and break the shapes below.
            query = hid_states
        # (layers * directions [* 2 for LSTM], batch, hid) -> (batch, q, hid)
        query = query.transpose(0, 1)
        x_attn = torch.tanh(self.attn(query))
        # (batch, q, hid_size) @ (batch, hid_size steps, width): relies on
        # the step count (conv channels) equalling the RNN hidden size.
        x = x_attn.bmm(x_out)
        x = x.transpose(2, 1)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        return F.softmax(self.fc(x), dim=-1)
243
      
244
      
245
if __name__ == '__main__':
    # Construct one instance of each model family as a smoke check.
    rnn_attn = RNNAttentionModel(
        input_size=1,
        hid_size=64,
        rnn_type='lstm',
        bidirectional=False,
    )
    rnn = RNNModel(
        input_size=1,
        hid_size=64,
        rnn_type='lstm',
        bidirectional=True,
    )
    cnn = CNN(hid_size=128, num_classes=5)