ecg_classification/models.py
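"""Model definitions for ECG signal classification.

Defines a Swish activation, a ConvNormPool building block, and three
classifiers: CNN, RNNModel, and RNNAttentionModel.
"""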
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
|
|
class ConvNormPool(nn.Module):
    """Conv block with skip connection: three Conv1d -> norm -> Swish stages,
    with conv_1's output added to conv_3's, followed by MaxPool1d(2)."""
    def __init__(
        self,
        input_size,
        hidden_size,
        kernel_size,
        norm_type='batchnorm'
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.conv_1 = nn.Conv1d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_2 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_3 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.swish_1 = Swish()
        self.swish_2 = Swish()
        self.swish_3 = Swish()
        if norm_type == 'group':
            self.normalization_1 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_2 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_3 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
        else:
            self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size)

        self.pool = nn.MaxPool1d(kernel_size=2)
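    # Shape note: each conv shortens the signal by kernel_size - 1 samples and
    # the causal left-padding (F.pad with (kernel_size - 1, 0)) restores it, so
    # the sequence length is preserved until MaxPool1d(2) halves it. Overall,
    # (batch, input_size, L) maps to (batch, hidden_size, L // 2), with conv_1's
    # output added to conv_3's output as the skip connection.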
|
|
    def forward(self, input):
        conv1 = self.conv_1(input)
        x = self.normalization_1(conv1)
        x = self.swish_1(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.conv_2(x)
        x = self.normalization_2(x)
        x = self.swish_2(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        conv3 = self.conv_3(x)
        x = self.normalization_3(conv1 + conv3)
        x = self.swish_3(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.pool(x)
        return x
|
|
class RNN(nn.Module):
    """RNN module (cell type: LSTM or GRU)."""
    def __init__(
        self,
        input_size,
        hid_size,
        num_rnn_layers=1,
        dropout_p=0.2,
        bidirectional=False,
        rnn_type='lstm',
    ):
        super().__init__()

        if rnn_type == 'lstm':
            self.rnn_layer = nn.LSTM(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
        else:
            self.rnn_layer = nn.GRU(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )

    def forward(self, input):
        # For an LSTM, hidden_states is a (h_n, c_n) tuple; for a GRU it is a
        # single h_n tensor.
        outputs, hidden_states = self.rnn_layer(input)
        return outputs, hidden_states
|
|
class CNN(nn.Module):
    def __init__(
        self,
        input_size=1,
        hid_size=256,
        kernel_size=5,
        num_classes=5,
    ):
        super().__init__()

        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size//2,
            kernel_size=kernel_size,
        )
        self.conv3 = ConvNormPool(
            input_size=hid_size//2,
            hidden_size=hid_size//4,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(in_features=hid_size//4, out_features=num_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avgpool(x)
        # flatten: (batch, hid_size // 4, 1) -> (batch, hid_size // 4)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)  # class probabilities, not logits
        return x
|
|
class RNNModel(nn.Module):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
    ):
        super().__init__()

        self.rnn_layer = RNN(
            input_size=46,  # per-step feature size after the two ConvNormPool blocks (signal length // 4)
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x, _ = self.rnn_layer(x)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)
        return x
|
|
class RNNAttentionModel(nn.Module):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
    ):
        super().__init__()

        self.rnn_layer = RNN(
            input_size=46,  # per-step feature size after the two ConvNormPool blocks (signal length // 4)
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveMaxPool1d(1)  # note: adaptive *max* pooling despite the attribute name
        self.attn = nn.Linear(hid_size, hid_size, bias=False)
        self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x_out, hid_states = self.rnn_layer(x)
        # Assumes an LSTM: hid_states is (h_n, c_n), each of shape
        # (num_layers * num_directions, batch, hid_size).
        x = torch.cat([hid_states[0], hid_states[1]], dim=0).transpose(0, 1)
        x_attn = torch.tanh(self.attn(x))
        x = x_attn.bmm(x_out)
        x = x.transpose(2, 1)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=-1)
        return x
|
|
if __name__ == '__main__':
    rnn_attn = RNNAttentionModel(1, 64, 'lstm', False)
    rnn = RNNModel(1, 64, 'lstm', True)
    cnn = CNN(num_classes=5, hid_size=128)
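    # Illustrative smoke test: the 187-sample signal length is an assumption
    # inferred from the hardcoded RNN input_size of 46 (187 // 4 == 46);
    # adjust it to your data.
    dummy = torch.randn(8, 1, 187)  # (batch, channels, samples)
    for name, model in [('rnn_attn', rnn_attn), ('rnn', rnn), ('cnn', cnn)]:
        model.eval()
        with torch.no_grad():
            out = model(dummy)
        print(name, out.shape)  # expected: torch.Size([8, 5])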