code/eval-classifier.py

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, GaussianNoise, GaussianDropout, Conv1D, multiply
from tensorflow.keras.models import Model
import tensorflow.keras.backend as backend
import numpy as np
from sklearn.preprocessing import LabelEncoder, normalize
import sklearn.metrics as sk
from sklearn.metrics import mean_absolute_error as mae
from sklearn.metrics import mean_squared_error as mse
import tensorflow.keras.utils as np_utils
from tensorflow.keras.callbacks import CSVLogger, History
import data_provider
import os
from options import opt, MODEL_DIR
import time
import shap


def eval_encoder(training=False):
    # Load the expression matrix (m_rna), labels, gene list and sample IDs from HDF5.
    load_file = '../data_process/tcga_pfi.h5' if opt.pfi else '../data_process/tcga.h5'
    m_rna, label, gene, sample_id = data_provider.load_h5_all(load_file, True)
    if not opt.pfi:
        m_rna_o, label_o, gene_o, sample_id_o = data_provider.load_h5_all('../data_process/other.h5', True)
        m_rna, label, sample_id = np.concatenate((m_rna, m_rna_o)), np.concatenate((label, label_o)), np.concatenate((sample_id, sample_id_o))
    m_rna = normalize(X=m_rna, axis=0, norm="max")
    print('Data size:', m_rna.shape, label)

    if not opt.use_all:
        # according to our indexes
        top_idx_doctor = np.array([471, 1213, 1632, 1635, 1636, 2743, 2774, 3020, 4880, 7057,
                                   7146, 7213, 8282, 9619, 9899, 9914, 10079, 10319, 10479, 11629,
                                   12569, 13075, 13343, 13815, 15103, 15481, 15716, 17130, ])

        if opt.random_input:
            np.random.seed(opt.seed)
            top_idx = np.arange(m_rna.shape[1])
        else:  # modify the following filenames accordingly
            if opt.top_type == 0:
                if opt.pfi:
                    with open('../shap_log_pfi.txt') as f:
                        lines = f.readlines()
                    top_idx = np.array(lines[2].split(' ')).astype(int)[-10:]
                else:
                    with open('../shap_log.txt') as f:
                        lines = f.readlines()
                    top_idx = np.array(lines[2].split(' ')).astype(int)[-10:]
            elif opt.top_type == 1:
                top_idx = top_idx_doctor
        # shuffle the candidate indices and keep the first top_k genes
        np.random.seed(opt.seed)
        np.random.shuffle(top_idx)
        top_idx = top_idx[:opt.top_k]
        print(top_idx)
        m_rna = m_rna[:, top_idx]
    """first random: train and test sets"""
    indices = np.arange(m_rna.shape[0])
    np.random.seed(1)
    np.random.shuffle(indices)
    m_rna2 = m_rna[indices]
    label2 = label[indices]
    sample_id = sample_id[indices]
    categorical_label = np_utils.to_categorical(label2)
    """to save the data split strategy for other analysis"""
    # tofile = np.stack((sample_id.astype(str), label2.astype(str)), axis=1)
    # np.savetxt(X=tofile, fname=MODEL_DIR + "/sample_id.txt", delimiter=",", fmt='%s')

    m_rna_train = m_rna2[:-opt.test_size, ]
    m_rna_test = m_rna2[-opt.test_size:, ]
    categorical_label_train = categorical_label[:-opt.test_size, ]
    categorical_label_test = categorical_label[-opt.test_size:, ]
    label_train = label2[:-opt.test_size, ]
    label_test = label2[-opt.test_size:, ]
    sample_id_test = sample_id[-opt.test_size:, ]

    """pr sample operations"""
    pr_idx_train = np.array([i for i, e in enumerate(label_train) if e == 2 or e == 3])
    pr_idx_test = np.array([i for i, e in enumerate(label_test) if e == 2 or e == 3])
    pr_m_rna_train, pr_m_rna_test = m_rna_train[pr_idx_train], m_rna_test[pr_idx_test]
    pr_label_train, pr_label_test = label_train[pr_idx_train] - 2, label_test[pr_idx_test] - 2

    print('pr samples in training set:', len(pr_idx_train))
    print('pr samples in testing set:', len(pr_idx_test))
    print('healthy samples in training set:', sum([1 for x in pr_label_train if x == 1]))
    print('healthy samples in testing set:', sum([1 for x in pr_label_test if x == 1]))
    print('PR train and test size:', pr_m_rna_train.shape, pr_m_rna_test.shape)
    pr_categorical_label_train = np_utils.to_categorical(pr_label_train)
    pr_categorical_label_test = np_utils.to_categorical(pr_label_test)

    print("data loading has just been finished")

    def create_model():
        inputs = Input(shape=(m_rna.shape[1],), name="inputs")
        inputs_1 = Dense(8, activation="relu", name="inputs_1")(inputs)
        encoded = Dense(4, activation='relu', name='encoded')(inputs_1)
        cl_0 = Dense(units=pr_categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
        m = Model(inputs=inputs, outputs=[cl_0])

        m.compile(optimizer='adam',
                  loss=[tf.keras.losses.CategoricalCrossentropy()],  # "cosine_similarity"],
                  loss_weights=[1],  # , 0.5
                  metrics={"category": "acc"})  # , "cl_disease": "acc"
        return m

    model = create_model()
    checkpoint_path = os.path.join(MODEL_DIR, 'my_model.h5')
    # model.summary()
    if training:
        # file_writer = tf.summary.create_file_writer(MODEL_DIR + "/metrics")
        # file_writer.set_as_default()
        def lr_scheduler(epoch):
            lr = 0.01
            if epoch < 200:
                lr *= 0.99999999
            else:
                lr *= 0.99999
            tf.summary.scalar('Learning Rate', data=lr, step=epoch)
            return lr

        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR)
        # lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
        model.fit(pr_m_rna_train, pr_categorical_label_train, batch_size=opt.batch_size,
                  epochs=opt.max_epoch,
                  callbacks=[tensorboard_callback],
                  validation_data=(pr_m_rna_test, pr_categorical_label_test),
                  verbose=2)
        model.save_weights(filepath=checkpoint_path)
        print("fitting has just been finished")

    else:
        model.load_weights(checkpoint_path)
    # evaluate on the PR test split
    m_rna_test = pr_m_rna_test
    categorical_label_test = pr_categorical_label_test
    data_pred = model.predict(m_rna_test, batch_size=opt.batch_size, verbose=2)
    np.savetxt(X=m_rna_test, fname=MODEL_DIR + "/test_gene.csv", delimiter=",", fmt='%1.3f')
    np.savetxt(X=label, fname=MODEL_DIR + "/label.csv", delimiter=",", fmt='%1.3f')
    np.savetxt(X=data_pred, fname=MODEL_DIR + "/pred_label.csv", delimiter=",", fmt='%1.3f')
    """ argmax """
    y_pred = np.argmax(data_pred, axis=1)
    y_gt = np.argmax(categorical_label_test, axis=1)
    confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1])
    print(confusion_0)
    balanced_acc = sk.balanced_accuracy_score(y_gt, y_pred)
    """ logits for roc """
    pred_logit = data_pred[:, 0]
    gt_logit = categorical_label_test[:, 0]
    """ log """
    log1 = open(os.path.join(MODEL_DIR, 'log_blacc.txt'), 'a')
    log2 = open(os.path.join(MODEL_DIR, 'log_roc.txt'), 'a')

    roc_feat = ''
    """Only execute when using PR data (two labels)"""
    auc = sk.roc_auc_score(gt_logit, pred_logit)
    fpr, tpr, thresh = sk.roc_curve(gt_logit, pred_logit)
    roc_feat = {'auc': auc, 'fpr': [], 'tpr': []}
    for e in fpr:
        roc_feat['fpr'].append(str(e))
    for e in tpr:
        roc_feat['tpr'].append(str(e))

    def log_string(out1, out2):
        log1.write(str(out1))
        log1.write('\n')
        log1.flush()
        print(out1)
        if out2:
            log2.write(str(out2['auc']))
            log2.write('\n')
            roc_x, roc_y = ' '.join(out2['fpr']), ' '.join(out2['tpr'])
            log2.write(roc_x)
            log2.write('\n')
            log2.write(roc_y)
            log2.write('\n')
            log2.flush()
            print(out2)

    if opt.use_argmax:
        if opt.use_all:
            confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 36)))
            log_string(confusion_1, None)
        else:
            confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 4)))
            log_string(confusion_1, roc_feat)
    else:
        log_string(balanced_acc, roc_feat)


if __name__ == '__main__':
    if opt.phase == 'train':
        if not os.path.exists(os.path.join(MODEL_DIR, 'code/')):
            os.makedirs(os.path.join(MODEL_DIR, 'code/'))
        os.system('cp -r * %s' % (os.path.join(MODEL_DIR, 'code/')))  # bkp of model def
        eval_encoder(True)
    elif opt.phase == 'test':
        eval_encoder(False)