Diff of /src/graph.py [000000] .. [a378de]

from __future__ import division, print_function
from keras.models import Model
from keras.layers import Input, Conv1D, Dense, add, Flatten, Dropout, MaxPooling1D, Activation, BatchNormalization, Lambda
from keras.optimizers import Adam
from keras.saving import register_keras_serializable
import tensorflow as tf

@register_keras_serializable(package="custom")
def zeropad(x):
    """
    Zero-pad the channel axis so the shortcut branch matches the doubled
    filter count of the main branch.

    zeropad and zeropad_output_shape are from
    https://github.com/awni/ecg/blob/master/ecg/network.py
    """
    y = tf.zeros_like(x)
    return tf.concat([x, y], axis=2)

@register_keras_serializable(package="custom")
def zeropad_output_shape(input_shape):
    shape = list(input_shape)
    assert len(shape) == 3
    shape[2] *= 2
    return tuple(shape)
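
# A minimal sanity check (not part of the original file; shapes chosen only
# for illustration) showing that zeropad doubles the channel dimension:
#
#     x = tf.zeros((1, 256, 32))
#     print(zeropad(x).shape)                       # (1, 256, 64)
#     print(zeropad_output_shape((None, 256, 32)))  # (None, 256, 64)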


def ECG_model(config):
    """
    Implementation of the model in https://www.nature.com/articles/s41591-018-0268-3,
    with reference to the code at
    https://github.com/awni/ecg/blob/master/ecg/network.py
    and
    https://github.com/fernandoandreotti/cinc-challenge2017/blob/master/deeplearn-approach/train_model.py

    Expects `config` to provide filter_length, kernel_size, drop_rate and
    input_size, all of which are read below.
    """
    def first_conv_block(inputs, config):
        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(inputs)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)

        # identity shortcut: pool_size=1 with strides=1 leaves the tensor unchanged
        shortcut = MaxPooling1D(pool_size=1,
                                strides=1)(layer)

        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(layer)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Dropout(config.drop_rate)(layer)
        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(layer)
        return add([shortcut, layer])

    def main_loop_blocks(layer, config):
        filter_length = config.filter_length
        n_blocks = 15
        for block_index in range(n_blocks):
            # subsample the time axis on every other block
            subsample_length = 2 if block_index % 2 == 0 else 1
            shortcut = MaxPooling1D(pool_size=subsample_length)(layer)

            # every 4th block: double the number of filters and zero-pad the
            # shortcut so the shapes of both branches match
            if block_index % 4 == 0 and block_index > 0:
                shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)
                filter_length *= 2

            layer = BatchNormalization()(layer)
            layer = Activation('relu')(layer)
            layer = Conv1D(filters=filter_length,
                           kernel_size=config.kernel_size,
                           padding='same',
                           strides=subsample_length,
                           kernel_initializer='he_normal')(layer)
            layer = BatchNormalization()(layer)
            layer = Activation('relu')(layer)
            layer = Dropout(config.drop_rate)(layer)
            layer = Conv1D(filters=filter_length,
                           kernel_size=config.kernel_size,
                           padding='same',
                           strides=1,
                           kernel_initializer='he_normal')(layer)
            layer = add([shortcut, layer])
        return layer
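
    # Shape bookkeeping (a sketch, not in the original file, assuming
    # config.input_size=256 and config.filter_length=32): the time axis is
    # halved on the 8 even-indexed blocks and filters double at blocks 4, 8
    # and 12, so the residual stack maps
    #
    #     (None, 256, 32) -> (None, 1, 256)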

    def output_block(layer, config):
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Flatten()(layer)
        # `inputs` and `len_classes` are captured from the enclosing ECG_model scope
        outputs = Dense(len_classes, activation='softmax')(layer)
        model = Model(inputs=inputs, outputs=outputs)

        adam = Adam(learning_rate=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False)
        model.compile(optimizer=adam,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        model.summary()
        return model
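
    # Note (not in the original file): categorical_crossentropy expects one-hot
    # labels of shape (batch, len_classes). A sketch assuming integer labels `y`:
    #
    #     y_onehot = tf.keras.utils.to_categorical(y, num_classes=len_classes)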

    # 'L','R','f','j','E','a','J','Q','e','S' occur too rarely or not at all
    # in the training set, so they are excluded
    classes = ['N', 'V', '/', 'A', 'F', '~']
    len_classes = len(classes)

    inputs = Input(shape=(config.input_size, 1), name='input')
    layer = first_conv_block(inputs, config)
    layer = main_loop_blocks(layer, config)
    return output_block(layer, config)
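

# Usage sketch (not part of the original file; `config` here is a stand-in
# for whatever configuration object the project passes in, assumed to carry
# the four attributes the model reads):
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(filter_length=32, kernel_size=16,
#                              drop_rate=0.2, input_size=256)
#     model = ECG_model(config)
#     model.save('ecg_model.keras')  # zeropad helpers are registered for reloading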