|
a |
|
b/src/graph.py |
|
|
1 |
from __future__ import division, print_function |
|
|
2 |
from keras.models import Model |
|
|
3 |
from keras.layers import Input, Conv1D, Dense, add, Flatten, Dropout,MaxPooling1D, Activation, BatchNormalization, Lambda |
|
|
4 |
from keras import backend as K |
|
|
5 |
from keras.optimizers import Adam |
|
|
6 |
from keras.saving import register_keras_serializable |
|
|
7 |
import tensorflow as tf |
|
|
8 |
|
|
|
9 |
@register_keras_serializable(package="custom") |
|
|
10 |
def zeropad(x): |
|
|
11 |
""" |
|
|
12 |
zeropad and zeropad_output_shapes are from |
|
|
13 |
https://github.com/awni/ecg/blob/master/ecg/network.py |
|
|
14 |
""" |
|
|
15 |
y = tf.zeros_like(x) |
|
|
16 |
return tf.concat([x, y], axis=2) |
|
|
17 |
|
|
|
18 |
@register_keras_serializable(package="custom") |
|
|
19 |
def zeropad_output_shape(input_shape): |
|
|
20 |
shape = list(input_shape) |
|
|
21 |
assert len(shape) == 3 |
|
|
22 |
shape[2] *= 2 |
|
|
23 |
return tuple(shape) |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def ECG_model(config): |
|
|
27 |
""" |
|
|
28 |
implementation of the model in https://www.nature.com/articles/s41591-018-0268-3 |
|
|
29 |
also have reference to codes at |
|
|
30 |
https://github.com/awni/ecg/blob/master/ecg/network.py |
|
|
31 |
and |
|
|
32 |
https://github.com/fernandoandreotti/cinc-challenge2017/blob/master/deeplearn-approach/train_model.py |
|
|
33 |
""" |
|
|
34 |
def first_conv_block(inputs, config): |
|
|
35 |
layer = Conv1D(filters=config.filter_length, |
|
|
36 |
kernel_size=config.kernel_size, |
|
|
37 |
padding='same', |
|
|
38 |
strides=1, |
|
|
39 |
kernel_initializer='he_normal')(inputs) |
|
|
40 |
layer = BatchNormalization()(layer) |
|
|
41 |
layer = Activation('relu')(layer) |
|
|
42 |
|
|
|
43 |
shortcut = MaxPooling1D(pool_size=1, |
|
|
44 |
strides=1)(layer) |
|
|
45 |
|
|
|
46 |
layer = Conv1D(filters=config.filter_length, |
|
|
47 |
kernel_size=config.kernel_size, |
|
|
48 |
padding='same', |
|
|
49 |
strides=1, |
|
|
50 |
kernel_initializer='he_normal')(layer) |
|
|
51 |
layer = BatchNormalization()(layer) |
|
|
52 |
layer = Activation('relu')(layer) |
|
|
53 |
layer = Dropout(config.drop_rate)(layer) |
|
|
54 |
layer = Conv1D(filters=config.filter_length, |
|
|
55 |
kernel_size=config.kernel_size, |
|
|
56 |
padding='same', |
|
|
57 |
strides=1, |
|
|
58 |
kernel_initializer='he_normal')(layer) |
|
|
59 |
return add([shortcut, layer]) |
|
|
60 |
|
|
|
61 |
def main_loop_blocks(layer, config): |
|
|
62 |
filter_length = config.filter_length |
|
|
63 |
n_blocks = 15 |
|
|
64 |
for block_index in range(n_blocks): |
|
|
65 |
|
|
|
66 |
subsample_length = 2 if block_index % 2 == 0 else 1 |
|
|
67 |
shortcut = MaxPooling1D(pool_size=subsample_length)(layer) |
|
|
68 |
|
|
|
69 |
# 5 is chosen instead of 4 from the original model |
|
|
70 |
if block_index % 4 == 0 and block_index > 0 : |
|
|
71 |
# double size of the network and match the shapes of both branches |
|
|
72 |
shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut) |
|
|
73 |
filter_length *= 2 |
|
|
74 |
|
|
|
75 |
layer = BatchNormalization()(layer) |
|
|
76 |
layer = Activation('relu')(layer) |
|
|
77 |
layer = Conv1D(filters= filter_length, |
|
|
78 |
kernel_size=config.kernel_size, |
|
|
79 |
padding='same', |
|
|
80 |
strides=subsample_length, |
|
|
81 |
kernel_initializer='he_normal')(layer) |
|
|
82 |
layer = BatchNormalization()(layer) |
|
|
83 |
layer = Activation('relu')(layer) |
|
|
84 |
layer = Dropout(config.drop_rate)(layer) |
|
|
85 |
layer = Conv1D(filters= filter_length, |
|
|
86 |
kernel_size=config.kernel_size, |
|
|
87 |
padding='same', |
|
|
88 |
strides= 1, |
|
|
89 |
kernel_initializer='he_normal')(layer) |
|
|
90 |
layer = add([shortcut, layer]) |
|
|
91 |
return layer |
|
|
92 |
|
|
|
93 |
def output_block(layer, config): |
|
|
94 |
layer = BatchNormalization()(layer) |
|
|
95 |
layer = Activation('relu')(layer) |
|
|
96 |
layer = Flatten()(layer) |
|
|
97 |
outputs = Dense(len_classes, activation='softmax')(layer) |
|
|
98 |
model = Model(inputs=inputs, outputs=outputs) |
|
|
99 |
|
|
|
100 |
adam = Adam(learning_rate=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False) |
|
|
101 |
model.compile(optimizer= adam, |
|
|
102 |
loss='categorical_crossentropy', |
|
|
103 |
metrics=['accuracy']) |
|
|
104 |
model.summary() |
|
|
105 |
return model |
|
|
106 |
|
|
|
107 |
classes = ['N','V','/','A','F','~']#,'L','R',f','j','E','a']#,'J','Q','e','S'] are too few or not in the trainset, so excluded out |
|
|
108 |
len_classes = len(classes) |
|
|
109 |
|
|
|
110 |
inputs = Input(shape=(config.input_size, 1), name='input') |
|
|
111 |
layer = first_conv_block(inputs, config) |
|
|
112 |
layer = main_loop_blocks(layer, config) |
|
|
113 |
return output_block(layer, config) |