Diff of /code/mlc-ae.py [000000] .. [2979df]

Switch to unified view

a b/code/mlc-ae.py
1
import tensorflow as tf
2
from tensorflow.keras.layers import Input, Dense, BatchNormalization, GaussianNoise, GaussianDropout, Conv1D, multiply
3
from tensorflow.keras.models import Model
4
import tensorflow.keras.backend as backend
5
import numpy as np
6
from sklearn.preprocessing import LabelEncoder, normalize, RobustScaler
7
import sklearn.metrics as sk
8
from sklearn.metrics import mean_absolute_error as mae
9
from sklearn.metrics import mean_squared_error as mse
10
import tensorflow.keras.utils as np_utils
11
from tensorflow.keras.callbacks import CSVLogger, History
12
import data_provider
13
import os
14
from options import opt, MODEL_DIR
15
import time
16
import shap
17
import pandas as pd
18
19
20
def mlc_ae(training=False):
21
    load_file = '../data_process/tcga_pfi.h5' if opt.pfi else '../data_process/tcga.h5'
22
    m_rna, label, gene, sample_id = data_provider.load_h5_all(load_file, True)
23
    if not opt.pfi:
24
        m_rna_o, label_o, gene_o, sample_id_o = data_provider.load_h5_all('../data_process/other.h5', True)
25
        m_rna, label, sample_id = np.concatenate((m_rna, m_rna_o)), np.concatenate((label, label_o)), \
26
                                  np.concatenate((sample_id, sample_id_o))
27
    m_rna = normalize(X=m_rna, axis=0, norm="max")
28
    print('feat and label size', m_rna.shape, label)
29
    """first random: train and test sets"""
30
    indices = np.arange(m_rna.shape[0])
31
    np.random.seed(1)
32
    np.random.shuffle(indices)
33
    m_rna2 = m_rna[indices]
34
    label2 = label[indices]
35
    sample_id = sample_id[indices]
36
    if opt.use_all:
37
        categorical_label = np_utils.to_categorical(label2, num_classes=6)
38
    else:
39
        categorical_label = np_utils.to_categorical(label2, num_classes=2)
40
    """to save the data split strategy for other analysis"""
41
    # tofile = np.stack((sample_id.astype(str), label2.astype(str)), axis=1)
42
    # np.savetxt(X=tofile, fname=MODEL_DIR + "/sample_id.txt", delimiter=",", fmt='%s')
43
44
    m_rna_train = m_rna2[:-opt.test_size, ]
45
    m_rna_test = m_rna2[-opt.test_size:, ]
46
    categorical_label_train = categorical_label[:-opt.test_size, ]
47
    categorical_label_test = categorical_label[-opt.test_size:, ]
48
    label_train = label2[:-opt.test_size, ]
49
    label_test = label2[-opt.test_size:, ]
50
51
    # pr sample operations
52
    pr_idx_train = np.array([i for i, e in enumerate(label_train) if e == 2 or e == 3])
53
    pr_idx_test = np.array([i for i, e in enumerate(label_test) if e == 2 or e == 3])
54
    print('healthy samples in training set:', sum([1 for x in label_train if x == 3]))
55
    print('healthy samples in testing set:', sum([1 for x in label_test if x == 3]))
56
57
    pr_m_rna_train = m_rna_train[pr_idx_train]
58
    pr_m_rna_test = m_rna_test[pr_idx_test]
59
    pr_label_train = label_train[pr_idx_train] - 2
60
    pr_label_test = label_test[pr_idx_test] - 2
61
62
    print('PR train and test size:', pr_label_train.shape, pr_idx_test.shape)
63
    pr_categorical_label_train = np_utils.to_categorical(pr_label_train)
64
    pr_categorical_label_test = np_utils.to_categorical(pr_label_test)
65
    print("data loading has just been finished")
66
67
    def create_model():
68
        inputs = Input(shape=(m_rna.shape[1],), name="inputs")
69
        inputs_0 = BatchNormalization(name="inputs_0")(inputs)
70
        inputs_1 = Dense(1024, activation="relu", name="inputs_1")(inputs_0)
71
        inputs_2 = BatchNormalization(name="inputs_2")(inputs_1)
72
        inputs_3 = Dense(256, activation="relu", name="inputs_3")(inputs_2)
73
        inputs_4 = BatchNormalization(name="inputs_4")(inputs_3)
74
        encoded = Dense(units=12, activation='relu', name='encoded')(inputs_4)
75
        inputs_5 = Dense(512, activation="relu", name="inputs_5")(encoded)
76
        decoded_tcga = Dense(units=m_rna.shape[1], activation='linear', name="m_rna")(inputs_5)
77
        if opt.use_all:
78
            cl_0 = Dense(units=categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
79
        else:
80
            cl_0 = Dense(units=pr_categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
81
        m = Model(inputs=inputs, outputs=[decoded_tcga, cl_0])
82
        m.compile(optimizer='adam',
83
                     loss=["mse", "cosine_similarity"],     # "cosine_similarity"],
84
                     loss_weights=[0.001, 0.5],     # , 0.5
85
                     metrics={"m_rna": ["mae", "mse"], "category": "acc"})     # , "cl_disease": "acc"
86
87
        return m
88
89
    model = create_model()
90
    checkpoint_path = os.path.join(MODEL_DIR, 'my_model.h5')
91
    # model.summary()
92
    if training:
93
        # file_writer = tf.summary.create_file_writer(MODEL_DIR + "/metrics")
94
        # file_writer.set_as_default()
95
        def lr_scheduler(epoch):
96
            lr = 0.01
97
            if epoch < 200:
98
                lr *= 0.99999999
99
            else:
100
                lr *= 0.99999
101
            tf.summary.scalar('Learning Rate', data=lr, step=epoch)
102
            return lr
103
104
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR)
105
        # lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
106
        if opt.use_all:
107
            model.fit(m_rna_train, [m_rna_train, categorical_label_train], batch_size=opt.batch_size,
108
                      epochs=opt.max_epoch,
109
                      callbacks=[tensorboard_callback],
110
                      validation_data=(m_rna_test, [m_rna_test, categorical_label_test]),
111
                      verbose=2)
112
        else:
113
            model.fit(pr_m_rna_train, [pr_m_rna_train, pr_categorical_label_train], batch_size=opt.batch_size,
114
                      epochs=opt.max_epoch,
115
                      callbacks=[tensorboard_callback],
116
                      validation_data=(pr_m_rna_test, [pr_m_rna_test, pr_categorical_label_test]),
117
                      verbose=2)
118
        model.save_weights(filepath=checkpoint_path)
119
        print("fitting has just been finished")
120
121
    else:
122
        model.load_weights(checkpoint_path)
123
        if not opt.use_all:
124
            m_rna_test = pr_m_rna_test
125
            categorical_label_test = pr_categorical_label_test
126
        data_pred = model.predict(m_rna_test, batch_size=opt.batch_size, verbose=2)
127
128
        """ argmax """
129
        y_pred = np.argmax(data_pred[1], axis=1)
130
        y_gt = label_test # np.argmax(categorical_label_test, axis=1)
131
        if opt.use_all:
132
            confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1, 2, 3, 4, 5])
133
        else:
134
            confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1])
135
        print(confusion_0)
136
        balanced_acc = sk.balanced_accuracy_score(y_gt, y_pred)
137
        acc = sk.accuracy_score(y_gt, y_pred)
138
        """ logits for roc """
139
        y_logit = data_pred[1][:, 0]
140
        x_logit = categorical_label_test[:, 0]
141
        """save for records"""
142
        np.savetxt(X=m_rna_test, fname=MODEL_DIR + "/test_gene.csv", delimiter=",", fmt='%1.3f')
143
        np.savetxt(X=label, fname=MODEL_DIR + "/label.csv", delimiter=",", fmt='%1.3f')
144
        np.savetxt(X=data_pred[0], fname=MODEL_DIR + "/pred_gene.csv", delimiter=",", fmt='%1.3f')
145
        np.savetxt(X=y_pred, fname=MODEL_DIR + "/pred_label.csv", delimiter=",", fmt='%1.3f')
146
        """ get latent representation """
147
        layer_name = "encoded"
148
        encoded_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
149
        encoded_output = encoded_layer_model.predict(m_rna_test)
150
        np.savetxt(X=encoded_output, fname=MODEL_DIR + "/latent_feat.csv", delimiter=",")
151
152
        """ log """
153
        log1 = open(os.path.join(MODEL_DIR, 'log_blacc.txt'), 'a')
154
        log2 = open(os.path.join(MODEL_DIR, 'log_roc.txt'), 'a')
155
        log3 = open(os.path.join(MODEL_DIR, 'log_acc.txt'), 'a')
156
157
        if not opt.use_all:
158
            """Only execute when using PR data (two labels)"""
159
            auc = sk.roc_auc_score(x_logit, y_logit)
160
            fpr, tpr, thresh = sk.roc_curve(x_logit, y_logit)
161
            roc_feat = {'auc': auc, 'fpr': [], 'tpr': []}
162
            for e in fpr:
163
                roc_feat['fpr'].append(str(e))
164
            for e in tpr:
165
                roc_feat['tpr'].append(str(e))
166
167
        if opt.use_shap:
168
            """ Depth Explainer """
169
            print("Processing SHAP...")
170
            model.load_weights(checkpoint_path)
171
            from matplotlib import colors as plt_colors
172
            layer_name = "category"
173
            encoded_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
174
            input_feat = m_rna_train if opt.sample_all == 'all' else pr_m_rna_train
175
            e = shap.GradientExplainer(encoded_layer_model, input_feat)
176
            # e = shap.DeepExplainer(encoded_layer_model, input_feat)
177
            shap_values = e.shap_values(input_feat)
178
            shap_load_gene = '../data_process/gene_pfi.csv' if opt.pfi else '../data_process/gene.csv'
179
            feat_name = np.loadtxt(shap_load_gene, dtype=str, delimiter=",")   #[:, 3]
180
            class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])
181
            print('class_inds', class_inds)
182
            colors = np.array(['yellowgreen', 'palevioletred', 'lightcoral', 'mediumpurple', 'cornflowerblue',
183
                               'orange'])[class_inds]
184
            cmap = plt_colors.ListedColormap(colors)
185
            shap.summary_plot(shap_values, input_feat, feature_names=feat_name, max_display=40,
186
                              plot_size=(12.0, 16.0, 2.0), plot_type='bar',
187
                              color=cmap, show=True, sort=True,
188
                              class_names=['Ovarian (T)', 'Ovarian (N)', 'Prostate (T)', 'Prostate (N)', 'Breast (T)',
189
                                           'Breast (N)'])
190
191
        def log_string(out1, out2, out3):
192
            log1.write(str(out1))
193
            log1.write('\n')
194
            log1.flush()
195
            print(out1)
196
            if out2:
197
                log2.write(str(out2['auc']))
198
                log2.write('\n')
199
                roc_x, roc_y = ' '.join(out2['fpr']), ' '.join(out2['tpr'])
200
                log2.write(roc_x)
201
                log2.write('\n')
202
                log2.write(roc_y)
203
                log2.write('\n')
204
                log2.flush()
205
                print(out2)
206
            if out3:
207
                log3.write(str(out3))
208
                log3.write('\n')
209
                log3.flush()
210
211
        if opt.use_argmax:
212
            if opt.use_all:
213
                confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 36)))
214
                log_string(confusion_1, None, None)
215
            else:
216
                confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 4)))
217
                log_string(confusion_1, roc_feat, acc)
218
        else:
219
            log_string(balanced_acc, None, acc)
220
221
222
if __name__ == '__main__':
223
    if opt.phase == 'train':
224
        if not os.path.exists(os.path.join(MODEL_DIR, 'code/')):
225
            os.makedirs(os.path.join(MODEL_DIR, 'code/'))
226
            os.system('cp -r * %s' % (os.path.join(MODEL_DIR, 'code/')))  # bkp of model def
227
        mlc_ae(True)
228
    elif opt.phase == 'test':
229
        mlc_ae(False)