# survival4D/nn/tf/models.py
1
import keras
2
from typing import Tuple
3
from keras import backend as K
4
from keras.models import Model
5
from keras.layers import Input, BatchNormalization, Activation
6
from keras.layers.core import Dense, Dropout
7
from keras.optimizers import Adam
8
from keras.regularizers import l1
9
10
11
def baseline_autoencoder(
    input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int, l1_reg_lambda_exp: float,
) -> Model:
    """Baseline autoencoder as published in https://www.nature.com/articles/s42256-019-0019-2

    Dropout on the input, a two-layer ReLU encoder, a single-unit linear risk head on the bottleneck,
    and a one-hidden-layer ReLU decoder that linearly reconstructs the input.

    Args:
        input_shape: Number of input features, i.e. the length of the flat input vector (an int,
            not a tuple).
        dropout: Dropout rate applied to the input vector.
        num_ae_units1: Width of the first encoder layer and of the decoder hidden layer.
        num_ae_units2: Width of the bottleneck layer, named ``'encoded'``.
        l1_reg_lambda_exp: Base-10 exponent of the L1 activity-regularization weight applied to the
            first encoder layer; the weight used is ``10 ** l1_reg_lambda_exp``.

    Returns:
        A Keras ``Model`` (not compiled here) with outputs ``[decoded, risk_pred]``.
    """
    n_features = int(input_shape)
    inputvec = Input(shape=(n_features,))
    # No ``input_shape`` kwarg on Dropout: in the functional API the shape is fixed by ``Input``.
    x = Dropout(dropout)(inputvec)

    # Encoder. Unit counts are cast with int() because hyperparameter searches may pass floats.
    x = Dense(units=int(num_ae_units1), activation='relu', activity_regularizer=l1(10 ** l1_reg_lambda_exp))(x)
    encoded = Dense(units=int(num_ae_units2), activation='relu', name='encoded')(x)

    # Risk head: one linear unit on the bottleneck.
    risk_pred = Dense(units=1, activation='linear', name='predicted_risk')(encoded)

    # Decoder: mirror of the first encoder layer, then linear reconstruction of the input.
    z = Dense(units=int(num_ae_units1), activation='relu')(encoded)
    decoded = Dense(units=n_features, activation='linear', name='decoded')(z)

    return Model(inputs=inputvec, outputs=[decoded, risk_pred])


def baseline_bn_autoencoder(
    input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int, l1_reg_lambda_exp: float,
) -> Model:
    """Add batch normalization to each layer before relu activation, based on baseline_autoencoder.

    Each hidden layer is a linear ``Dense`` followed by ``BatchNormalization`` and then a separate
    ReLU ``Activation`` layer. Output layers are unchanged from ``baseline_autoencoder``.

    Args:
        input_shape: Number of input features, i.e. the length of the flat input vector (an int,
            not a tuple).
        dropout: Dropout rate applied to the input vector.
        num_ae_units1: Width of the first encoder layer and of the decoder hidden layer.
        num_ae_units2: Width of the bottleneck; its ReLU activation layer is named ``'encoded'``.
        l1_reg_lambda_exp: Base-10 exponent of the L1 activity-regularization weight applied to the
            first encoder ``Dense`` layer; the weight used is ``10 ** l1_reg_lambda_exp``.

    Returns:
        A Keras ``Model`` (not compiled here) with outputs ``[decoded, risk_pred]``.
    """

    def dense_bn_relu(tensor, units, name=None, **dense_kwargs):
        # Linear Dense -> BatchNormalization -> ReLU. Only the activation layer receives ``name``.
        # Layers are created in the same order as the original inline code, so auto-generated
        # layer names are unchanged.
        tensor = Dense(units=int(units), activation=None, **dense_kwargs)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu", name=name)(tensor)

    n_features = int(input_shape)
    inputvec = Input(shape=(n_features,))
    # No ``input_shape`` kwarg on Dropout: in the functional API the shape is fixed by ``Input``.
    x = Dropout(dropout)(inputvec)

    # Encoder. The L1 activity regularizer sits on the first Dense output, which is pre-BN.
    x = dense_bn_relu(x, num_ae_units1, activity_regularizer=l1(10 ** l1_reg_lambda_exp))
    encoded = dense_bn_relu(x, num_ae_units2, name='encoded')

    # Risk head: one linear unit on the bottleneck.
    risk_pred = Dense(units=1, activation='linear', name='predicted_risk')(encoded)

    # Decoder: mirror of the first encoder layer, then linear reconstruction of the input.
    z = dense_bn_relu(encoded, num_ae_units1)
    decoded = Dense(units=n_features, activation='linear', name='decoded')(z)

    return Model(inputs=inputvec, outputs=[decoded, risk_pred])


def model3_bn_autoencoder(
    input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int, num_risk_units: int,
    l1_reg_lambda_exp: float,
) -> Model:
    """
    Add one more relu layer between encoded and risk_pred, based on baseline_bn_autoencoder.
    Model 3 architecture: https://arxiv.org/pdf/1910.02951v1.pdf

    Each hidden layer is a linear ``Dense`` followed by ``BatchNormalization`` and a ReLU
    ``Activation``. The risk branch gets its own hidden layer between the bottleneck and the
    single-unit linear risk output.

    Args:
        input_shape: Number of input features, i.e. the length of the flat input vector (an int,
            not a tuple).
        dropout: Dropout rate applied to the input vector.
        num_ae_units1: Width of the first encoder layer and of the decoder hidden layer.
        num_ae_units2: Width of the bottleneck; its ReLU activation layer is named ``'encoded'``.
        num_risk_units: Width of the extra hidden layer in the risk branch.
        l1_reg_lambda_exp: Base-10 exponent of the L1 activity-regularization weight applied to the
            first encoder ``Dense`` layer; the weight used is ``10 ** l1_reg_lambda_exp``.

    Returns:
        A Keras ``Model`` (not compiled here) with outputs ``[decoded, risk_pred]``.
    """

    def dense_bn_relu(tensor, units, name=None, **dense_kwargs):
        # Linear Dense -> BatchNormalization -> ReLU. Only the activation layer receives ``name``.
        # int() keeps all unit counts consistent, since hyperparameter searches may pass floats.
        tensor = Dense(units=int(units), activation=None, **dense_kwargs)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu", name=name)(tensor)

    n_features = int(input_shape)
    inputvec = Input(shape=(n_features,))
    # No ``input_shape`` kwarg on Dropout: in the functional API the shape is fixed by ``Input``.
    x = Dropout(dropout)(inputvec)

    # Encoder. The L1 activity regularizer sits on the first Dense output, which is pre-BN.
    x = dense_bn_relu(x, num_ae_units1, activity_regularizer=l1(10 ** l1_reg_lambda_exp))
    encoded = dense_bn_relu(x, num_ae_units2, name='encoded')

    # Extra hidden layer in the risk branch. It previously reused the name 'encoded', which Keras
    # rejects because layer names must be unique within a model, so it is left auto-named.
    risk_hidden = dense_bn_relu(encoded, num_risk_units)
    risk_pred = Dense(units=1, activation='linear', name='predicted_risk')(risk_hidden)

    # Decoder: mirror of the first encoder layer, then linear reconstruction of the input.
    z = dense_bn_relu(encoded, num_ae_units1)
    decoded = Dense(units=n_features, activation='linear', name='decoded')(z)

    return Model(inputs=inputvec, outputs=[decoded, risk_pred])


def deep_model3_bn_autoencoder(
    input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int, num_ae_units3: int,
    num_risk_units: int, l1_reg_lambda_exp: float,
) -> Model:
    """
    Add one more relu layer in autoencoder, based on model3_bn_autoencoder.
    Model 3 architecture: https://arxiv.org/pdf/1910.02951v1.pdf

    Encoder widths: ``num_ae_units1 -> num_ae_units2 -> num_ae_units3`` (bottleneck). The decoder
    mirrors them as ``num_ae_units2 -> num_ae_units1`` before a linear reconstruction. Each hidden
    layer is a linear ``Dense`` followed by ``BatchNormalization`` and a ReLU ``Activation``.

    Args:
        input_shape: Number of input features, i.e. the length of the flat input vector (an int,
            not a tuple).
        dropout: Dropout rate applied to the input vector.
        num_ae_units1: Width of the first encoder layer and of the last decoder hidden layer.
        num_ae_units2: Width of the second encoder layer and of the first decoder hidden layer.
        num_ae_units3: Width of the bottleneck; its ReLU activation layer is named ``'encoded'``.
        num_risk_units: Width of the extra hidden layer in the risk branch.
        l1_reg_lambda_exp: Base-10 exponent of the L1 activity-regularization weight applied to the
            first encoder ``Dense`` layer; the weight used is ``10 ** l1_reg_lambda_exp``.

    Returns:
        A Keras ``Model`` (not compiled here) with outputs ``[decoded, risk_pred]``.
    """

    def dense_bn_relu(tensor, units, name=None, **dense_kwargs):
        # Linear Dense -> BatchNormalization -> ReLU. Only the activation layer receives ``name``.
        # int() keeps all unit counts consistent, since hyperparameter searches may pass floats.
        tensor = Dense(units=int(units), activation=None, **dense_kwargs)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu", name=name)(tensor)

    n_features = int(input_shape)
    inputvec = Input(shape=(n_features,))
    # No ``input_shape`` kwarg on Dropout: in the functional API the shape is fixed by ``Input``.
    x = Dropout(dropout)(inputvec)

    # Encoder. The L1 activity regularizer sits on the first Dense output, which is pre-BN.
    x = dense_bn_relu(x, num_ae_units1, activity_regularizer=l1(10 ** l1_reg_lambda_exp))
    x = dense_bn_relu(x, num_ae_units2)
    encoded = dense_bn_relu(x, num_ae_units3, name='encoded')

    # Extra hidden layer in the risk branch. It previously reused the name 'encoded', which Keras
    # rejects because layer names must be unique within a model, so it is left auto-named.
    risk_hidden = dense_bn_relu(encoded, num_risk_units)
    risk_pred = Dense(units=1, activation='linear', name='predicted_risk')(risk_hidden)

    # Decoder: mirror of the encoder's first two layers, then linear reconstruction of the input.
    x = dense_bn_relu(encoded, num_ae_units2)
    z = dense_bn_relu(x, num_ae_units1)
    decoded = Dense(units=n_features, activation='linear', name='decoded')(z)

    return Model(inputs=inputvec, outputs=[decoded, risk_pred])


def model_factory(model_name: str, **kwargs):
    """Build one of the survival autoencoder networks defined in this module, selected by name.

    The Keras session is cleared before the name is looked up, so the session is reset even when
    the name turns out to be unsupported.

    Args:
        model_name: One of ``"baseline_autoencoder"``, ``"baseline_bn_autoencoder"``,
            ``"model3_bn_autoencoder"`` or ``"deep_model3_bn_autoencoder"``.
        **kwargs: Passed through unchanged to the selected builder function.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Start every network from an empty computation graph.
    K.clear_session()

    builders = {
        "baseline_autoencoder": baseline_autoencoder,
        "baseline_bn_autoencoder": baseline_bn_autoencoder,
        "model3_bn_autoencoder": model3_bn_autoencoder,
        "deep_model3_bn_autoencoder": deep_model3_bn_autoencoder,
    }
    build = builders.get(model_name)
    if build is None:
        raise ValueError("Model name {} has not been implemented.".format(model_name))
    return build(**kwargs)