code/eval-classifier.py
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, GaussianNoise, GaussianDropout, Conv1D, multiply
from tensorflow.keras.models import Model
import tensorflow.keras.backend as backend
import numpy as np
from sklearn.preprocessing import LabelEncoder, normalize
import sklearn.metrics as sk
from sklearn.metrics import mean_absolute_error as mae
from sklearn.metrics import mean_squared_error as mse
import tensorflow.keras.utils as np_utils
from tensorflow.keras.callbacks import CSVLogger, History
import data_provider
import os
from options import opt, MODEL_DIR
import time
import shap

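# Trains or evaluates a small feed-forward classifier (Input -> Dense(8) -> Dense(4) -> softmax)
# on TCGA mRNA expression data restricted to the PR samples (labels 2 and 3, remapped to 0/1).
# Behaviour is controlled by the flags defined in options.py, e.g. opt.phase, opt.pfi,
# opt.use_all, opt.top_type, opt.top_k, opt.random_input, opt.seed, opt.test_size,
# opt.batch_size, opt.max_epoch and opt.use_argmax.
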
def eval_encoder(training=False):
    load_file = '../data_process/tcga_pfi.h5' if opt.pfi else '../data_process/tcga.h5'
    m_rna, label, gene, sample_id = data_provider.load_h5_all(load_file, True)
    if not opt.pfi:
        m_rna_o, label_o, gene_o, sample_id_o = data_provider.load_h5_all('../data_process/other.h5', True)
        m_rna, label, sample_id = np.concatenate((m_rna, m_rna_o)), np.concatenate((label, label_o)), np.concatenate((sample_id, sample_id_o))
    # scale each gene (column) by its maximum absolute value
    m_rna = normalize(X=m_rna, axis=0, norm="max")
    print('Data size:', m_rna.shape, label)

    if not opt.use_all:
        # fixed list of pre-selected gene indices (used when opt.top_type == 1)
        top_idx_doctor = np.array([471, 1213, 1632, 1635, 1636, 2743, 2774, 3020, 4880, 7057,
                                   7146, 7213, 8282, 9619, 9899, 9914, 10079, 10319, 10479, 11629,
                                   12569, 13075, 13343, 13815, 15103, 15481, 15716, 17130, ])

        if opt.random_input:
            np.random.seed(opt.seed)
            top_idx = np.arange(m_rna.shape[1])
        else:  # modify the following filenames accordingly
            if opt.top_type == 0:
                # the third line of the SHAP log is assumed to hold space-separated feature
                # indices ordered by importance, with the most important features last
                if opt.pfi:
                    with open('../shap_log_pfi.txt') as f:
                        lines = f.readlines()
                        top_idx = np.array(lines[2].split(' ')).astype(int)[-10:]
                else:
                    with open('../shap_log.txt') as f:
                        lines = f.readlines()
                        top_idx = np.array(lines[2].split(' ')).astype(int)[-10:]
            elif opt.top_type == 1:
                top_idx = top_idx_doctor
        # shuffle the candidate indices and keep the first top_k of them
        np.random.seed(opt.seed)
        np.random.shuffle(top_idx)
        top_idx = top_idx[:opt.top_k]
        print(top_idx)
        m_rna = m_rna[:, top_idx]
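        # (optional) the selected column indices can be mapped back to gene identifiers via the
        # `gene` array returned by data_provider.load_h5_all, e.g. print(gene[top_idx]);
        # this assumes `gene` is index-aligned with the columns of m_rna.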
    """first shuffle: split samples into train and test sets"""
    indices = np.arange(m_rna.shape[0])
    np.random.seed(1)
    np.random.shuffle(indices)
    m_rna2 = m_rna[indices]
    label2 = label[indices]
    sample_id = sample_id[indices]
    categorical_label = np_utils.to_categorical(label2)
    """uncomment below to save the data split for other analyses"""
    # tofile = np.stack((sample_id.astype(str), label2.astype(str)), axis=1)
    # np.savetxt(X=tofile, fname=MODEL_DIR + "/sample_id.txt", delimiter=",", fmt='%s')

    m_rna_train = m_rna2[:-opt.test_size, ]
    m_rna_test = m_rna2[-opt.test_size:, ]
    categorical_label_train = categorical_label[:-opt.test_size, ]
    categorical_label_test = categorical_label[-opt.test_size:, ]
    label_train = label2[:-opt.test_size, ]
    label_test = label2[-opt.test_size:, ]
    sample_id_test = sample_id[-opt.test_size:, ]

    """PR sample operations: keep labels 2 and 3 and remap them to 0 and 1"""
    pr_idx_train = np.array([i for i, e in enumerate(label_train) if e == 2 or e == 3])
    pr_idx_test = np.array([i for i, e in enumerate(label_test) if e == 2 or e == 3])
    pr_m_rna_train, pr_m_rna_test = m_rna_train[pr_idx_train], m_rna_test[pr_idx_test]
    pr_label_train, pr_label_test = label_train[pr_idx_train] - 2, label_test[pr_idx_test] - 2

    print('pr samples in training set:', len(pr_idx_train))
    print('pr samples in testing set:', len(pr_idx_test))
    print('healthy samples in training set:', sum([1 for x in pr_label_train if x == 1]))
    print('healthy samples in testing set:', sum([1 for x in pr_label_test if x == 1]))
    print('PR train and test size:', pr_m_rna_train.shape, pr_m_rna_test.shape)
    pr_categorical_label_train = np_utils.to_categorical(pr_label_train)
    pr_categorical_label_test = np_utils.to_categorical(pr_label_test)

    print("data loading finished")

    def create_model():
        inputs = Input(shape=(m_rna.shape[1],), name="inputs")
        inputs_1 = Dense(8, activation="relu", name="inputs_1")(inputs)
        encoded = Dense(4, activation='relu', name='encoded')(inputs_1)
        cl_0 = Dense(units=pr_categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
        m = Model(inputs=inputs, outputs=[cl_0])

        m.compile(optimizer='adam',
                  loss=[tf.keras.losses.CategoricalCrossentropy()],  # "cosine_similarity"],
                  loss_weights=[1],  # , 0.5
                  metrics={"category": "acc"})  # , "cl_disease": "acc"
        return m

    model = create_model()
    checkpoint_path = os.path.join(MODEL_DIR, 'my_model.h5')
    # model.summary()
    if training:
        # file_writer = tf.summary.create_file_writer(MODEL_DIR + "/metrics")
        # file_writer.set_as_default()
        # only used if the LearningRateScheduler callback below is re-enabled
        def lr_scheduler(epoch):
            lr = 0.01
            if epoch < 200:
                lr *= 0.99999999
            else:
                lr *= 0.99999
            tf.summary.scalar('Learning Rate', data=lr, step=epoch)
            return lr

        # training curves are written to MODEL_DIR (view with: tensorboard --logdir <MODEL_DIR>)
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR)
        # lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
        model.fit(pr_m_rna_train, pr_categorical_label_train, batch_size=opt.batch_size,
                  epochs=opt.max_epoch,
                  callbacks=[tensorboard_callback],
                  validation_data=(pr_m_rna_test, pr_categorical_label_test),
                  verbose=2)
        model.save_weights(filepath=checkpoint_path)
        print("model fitting finished")

    else:
        model.load_weights(checkpoint_path)
        m_rna_test = pr_m_rna_test
        categorical_label_test = pr_categorical_label_test
        data_pred = model.predict(m_rna_test, batch_size=opt.batch_size, verbose=2)
        np.savetxt(X=m_rna_test, fname=MODEL_DIR + "/test_gene.csv", delimiter=",", fmt='%1.3f')
        # note: this writes the full (unsplit) label vector, not just the test labels
        np.savetxt(X=label, fname=MODEL_DIR + "/label.csv", delimiter=",", fmt='%1.3f')
        np.savetxt(X=data_pred, fname=MODEL_DIR + "/pred_label.csv", delimiter=",", fmt='%1.3f')
        """ argmax """
        y_pred = np.argmax(data_pred, axis=1)
        y_gt = np.argmax(categorical_label_test, axis=1)
        confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1])
        print(confusion_0)
        balanced_acc = sk.balanced_accuracy_score(y_gt, y_pred)
        """ class-0 probability used as the ROC score """
        pred_logit = data_pred[:, 0]
        gt_logit = categorical_label_test[:, 0]
        """ log """
        log1 = open(os.path.join(MODEL_DIR, 'log_blacc.txt'), 'a')
        log2 = open(os.path.join(MODEL_DIR, 'log_roc.txt'), 'a')

        """ROC is only meaningful for the two-label PR setting"""
        auc = sk.roc_auc_score(gt_logit, pred_logit)
        fpr, tpr, thresh = sk.roc_curve(gt_logit, pred_logit)
        roc_feat = {'auc': auc, 'fpr': [str(e) for e in fpr], 'tpr': [str(e) for e in tpr]}

        def log_string(out1, out2):
            log1.write(str(out1))
            log1.write('\n')
            log1.flush()
            print(out1)
            if out2:
                # each evaluation appends three lines to log_roc.txt: AUC, FPR values, TPR values
                log2.write(str(out2['auc']))
                log2.write('\n')
                roc_x, roc_y = ' '.join(out2['fpr']), ' '.join(out2['tpr'])
                log2.write(roc_x)
                log2.write('\n')
                log2.write(roc_y)
                log2.write('\n')
                log2.flush()
                print(out2)

        if opt.use_argmax:
            # flatten the confusion matrix into one space-separated row
            # (reshape(-1) avoids hardcoding the 2x2 or 6x6 size used previously)
            confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), -1)))
            if opt.use_all:
                log_string(confusion_1, None)
            else:
                log_string(confusion_1, roc_feat)
        else:
            log_string(balanced_acc, roc_feat)
        log1.close()
        log2.close()

if __name__ == '__main__':
    if opt.phase == 'train':
        if not os.path.exists(os.path.join(MODEL_DIR, 'code/')):
            os.makedirs(os.path.join(MODEL_DIR, 'code/'))
            os.system('cp -r * %s' % (os.path.join(MODEL_DIR, 'code/')))  # bkp of model def
        eval_encoder(True)
    elif opt.phase == 'test':
        eval_encoder(False)
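# Example invocation (illustrative only; the actual command-line interface is defined in
# options.py, which provides opt.phase and the other flags used above):
#   python eval-classifier.py --phase train
#   python eval-classifier.py --phase test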