|
a |
|
b/code/mlc-ae.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
from tensorflow.keras.layers import Input, Dense, BatchNormalization, GaussianNoise, GaussianDropout, Conv1D, multiply |
|
|
3 |
from tensorflow.keras.models import Model |
|
|
4 |
import tensorflow.keras.backend as backend |
|
|
5 |
import numpy as np |
|
|
6 |
from sklearn.preprocessing import LabelEncoder, normalize, RobustScaler |
|
|
7 |
import sklearn.metrics as sk |
|
|
8 |
from sklearn.metrics import mean_absolute_error as mae |
|
|
9 |
from sklearn.metrics import mean_squared_error as mse |
|
|
10 |
import tensorflow.keras.utils as np_utils |
|
|
11 |
from tensorflow.keras.callbacks import CSVLogger, History |
|
|
12 |
import data_provider |
|
|
13 |
import os |
|
|
14 |
from options import opt, MODEL_DIR |
|
|
15 |
import time |
|
|
16 |
import shap |
|
|
17 |
import pandas as pd |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def _build_model(n_features, n_classes):
    """Build and compile the two-headed autoencoder/classifier.

    The encoder compresses the input to a 12-unit layer named ``encoded``.
    Two heads branch from it: a linear reconstruction head named ``m_rna``
    and a softmax head named ``category``.

    Args:
        n_features: Width of the input (and of the reconstruction output).
        n_classes: Number of units in the softmax head.

    Returns:
        The compiled Keras ``Model`` with outputs ``[m_rna, category]``.
    """
    inputs = Input(shape=(n_features,), name="inputs")
    hidden = BatchNormalization(name="inputs_0")(inputs)
    hidden = Dense(1024, activation="relu", name="inputs_1")(hidden)
    hidden = BatchNormalization(name="inputs_2")(hidden)
    hidden = Dense(256, activation="relu", name="inputs_3")(hidden)
    hidden = BatchNormalization(name="inputs_4")(hidden)
    encoded = Dense(units=12, activation='relu', name='encoded')(hidden)

    decoder_hidden = Dense(512, activation="relu", name="inputs_5")(encoded)
    decoded_tcga = Dense(units=n_features, activation='linear', name="m_rna")(decoder_hidden)
    category = Dense(units=n_classes, activation="softmax", name="category")(encoded)

    model = Model(inputs=inputs, outputs=[decoded_tcga, category])
    # One loss per output, in output order: reconstruction, then category.
    model.compile(optimizer='adam',
                  loss=["mse", "cosine_similarity"],
                  loss_weights=[0.001, 0.5],
                  metrics={"m_rna": ["mae", "mse"], "category": "acc"})
    return model


def mlc_ae(training=False):
    """Train or evaluate the multi-task autoencoder on the expression data.

    Loads the HDF5 data set selected by ``opt.pfi``, max-normalizes every
    feature column, shuffles with a fixed seed and holds out the last
    ``opt.test_size`` samples as the test split. With ``opt.use_all`` the
    classifier head is trained on all six classes; otherwise only samples
    labelled 2 or 3 are used, re-labelled to 0/1.

    After training (or after loading the saved weights) the test split is
    scored, predictions and latent features are written as CSV files under
    ``MODEL_DIR``, and the metrics are appended to ``log_blacc.txt``,
    ``log_roc.txt`` and ``log_acc.txt``. With ``opt.use_shap`` a SHAP summary
    bar plot is shown as well.

    Args:
        training: When true, fit the model and save its weights to
            ``MODEL_DIR/my_model.h5``; otherwise load the weights from there.
    """
    load_file = '../data_process/tcga_pfi.h5' if opt.pfi else '../data_process/tcga.h5'
    m_rna, label, gene, sample_id = data_provider.load_h5_all(load_file, True)
    if not opt.pfi:
        m_rna_o, label_o, gene_o, sample_id_o = data_provider.load_h5_all('../data_process/other.h5', True)
        m_rna = np.concatenate((m_rna, m_rna_o))
        label = np.concatenate((label, label_o))
        sample_id = np.concatenate((sample_id, sample_id_o))
    m_rna = normalize(X=m_rna, axis=0, norm="max")
    print('feat and label size', m_rna.shape, label)

    # First shuffle: fixed seed so the train/test split is reproducible.
    indices = np.arange(m_rna.shape[0])
    np.random.seed(1)
    np.random.shuffle(indices)
    m_rna2 = m_rna[indices]
    label2 = label[indices]
    sample_id = sample_id[indices]
    # NOTE(review): with num_classes=2 this presumably requires every label to
    # be 0 or 1, while the PR selection below looks for labels 2 and 3 --
    # confirm which label scheme the two-class mode is run with.
    categorical_label = np_utils.to_categorical(label2, num_classes=6 if opt.use_all else 2)

    # To save the data split strategy for other analysis:
    # tofile = np.stack((sample_id.astype(str), label2.astype(str)), axis=1)
    # np.savetxt(X=tofile, fname=MODEL_DIR + "/sample_id.txt", delimiter=",", fmt='%s')

    # An explicit split index keeps test_size == 0 meaning "empty test split";
    # the slice pair [:-0] / [-0:] would swap the two splits instead.
    split = max(m_rna2.shape[0] - opt.test_size, 0)
    m_rna_train, m_rna_test = m_rna2[:split], m_rna2[split:]
    categorical_label_train, categorical_label_test = categorical_label[:split], categorical_label[split:]
    label_train, label_test = label2[:split], label2[split:]

    # PR sample operations: keep labels 2 and 3 and shift them to 0 and 1.
    # flatnonzero always yields an integer index array, even when empty.
    pr_idx_train = np.flatnonzero(np.isin(label_train, (2, 3)))
    pr_idx_test = np.flatnonzero(np.isin(label_test, (2, 3)))
    print('healthy samples in training set:', int(np.count_nonzero(label_train == 3)))
    print('healthy samples in testing set:', int(np.count_nonzero(label_test == 3)))

    pr_m_rna_train = m_rna_train[pr_idx_train]
    pr_m_rna_test = m_rna_test[pr_idx_test]
    pr_label_train = label_train[pr_idx_train] - 2
    pr_label_test = label_test[pr_idx_test] - 2

    print('PR train and test size:', pr_label_train.shape, pr_idx_test.shape)
    pr_categorical_label_train = np_utils.to_categorical(pr_label_train)
    pr_categorical_label_test = np_utils.to_categorical(pr_label_test)
    print("data loading has just been finished")

    # Choose once which samples and one-hot targets the classifier head uses.
    if opt.use_all:
        x_train, y_train = m_rna_train, categorical_label_train
        x_test, y_test = m_rna_test, categorical_label_test
    else:
        x_train, y_train = pr_m_rna_train, pr_categorical_label_train
        x_test, y_test = pr_m_rna_test, pr_categorical_label_test

    model = _build_model(m_rna.shape[1], y_train.shape[1])
    checkpoint_path = os.path.join(MODEL_DIR, 'my_model.h5')
    # model.summary()
    if training:
        # file_writer = tf.summary.create_file_writer(MODEL_DIR + "/metrics")
        # file_writer.set_as_default()
        def lr_scheduler(epoch):
            """Return the learning rate for *epoch* and log it as a scalar.

            Only used when the LearningRateScheduler callback below is
            enabled. The rate is recomputed from 0.01 on every call, so the
            factors do not compound across epochs.
            """
            lr = 0.01
            if epoch < 200:
                lr *= 0.99999999
            else:
                lr *= 0.99999
            tf.summary.scalar('Learning Rate', data=lr, step=epoch)
            return lr

        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR)
        # lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
        model.fit(x_train, [x_train, y_train], batch_size=opt.batch_size,
                  epochs=opt.max_epoch,
                  callbacks=[tensorboard_callback],
                  validation_data=(x_test, [x_test, y_test]),
                  verbose=2)
        model.save_weights(filepath=checkpoint_path)
        print("fitting has just been finished")
    else:
        model.load_weights(checkpoint_path)

    # NOTE(review): block nesting was reconstructed from a listing that had
    # lost its indentation; the evaluation below is assumed to run after both
    # training and weight loading -- confirm against the original file.
    data_pred = model.predict(x_test, batch_size=opt.batch_size, verbose=2)

    # argmax
    y_pred = np.argmax(data_pred[1], axis=1)
    # Ground truth comes from the same one-hot targets the model is scored
    # on, so it always has the same length and label range as y_pred (the raw
    # test labels cover every sample and use the un-shifted label values).
    y_gt = np.argmax(y_test, axis=1)
    confusion_labels = [0, 1, 2, 3, 4, 5] if opt.use_all else [0, 1]
    confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=confusion_labels)
    print(confusion_0)
    balanced_acc = sk.balanced_accuracy_score(y_gt, y_pred)
    acc = sk.accuracy_score(y_gt, y_pred)

    # Save for records.
    np.savetxt(X=x_test, fname=MODEL_DIR + "/test_gene.csv", delimiter=",", fmt='%1.3f')
    # NOTE(review): this writes the full, unshuffled label array rather than
    # the labels of the test split -- confirm that this is intended.
    np.savetxt(X=label, fname=MODEL_DIR + "/label.csv", delimiter=",", fmt='%1.3f')
    np.savetxt(X=data_pred[0], fname=MODEL_DIR + "/pred_gene.csv", delimiter=",", fmt='%1.3f')
    np.savetxt(X=y_pred, fname=MODEL_DIR + "/pred_label.csv", delimiter=",", fmt='%1.3f')

    # Get the latent representation.
    encoded_layer_model = Model(inputs=model.input, outputs=model.get_layer("encoded").output)
    encoded_output = encoded_layer_model.predict(x_test)
    np.savetxt(X=encoded_output, fname=MODEL_DIR + "/latent_feat.csv", delimiter=",")

    roc_feat = None
    if not opt.use_all:
        # Only executed when using PR data (two labels): ROC of the class-0
        # probability against the class-0 indicator.
        y_logit = data_pred[1][:, 0]
        x_logit = y_test[:, 0]
        auc = sk.roc_auc_score(x_logit, y_logit)
        fpr, tpr, thresh = sk.roc_curve(x_logit, y_logit)
        roc_feat = {'auc': auc,
                    'fpr': [str(value) for value in fpr],
                    'tpr': [str(value) for value in tpr]}

    if opt.use_shap:
        # Depth Explainer
        print("Processing SHAP...")
        model.load_weights(checkpoint_path)
        from matplotlib import colors as plt_colors
        category_model = Model(inputs=model.input, outputs=model.get_layer("category").output)
        input_feat = m_rna_train if opt.sample_all == 'all' else pr_m_rna_train
        explainer = shap.GradientExplainer(category_model, input_feat)
        # explainer = shap.DeepExplainer(category_model, input_feat)
        shap_values = explainer.shap_values(input_feat)
        shap_load_gene = '../data_process/gene_pfi.csv' if opt.pfi else '../data_process/gene.csv'
        feat_name = np.loadtxt(shap_load_gene, dtype=str, delimiter=",")  # [:, 3]
        # Order the classes by mean absolute SHAP value, largest first.
        class_inds = np.argsort([-np.abs(values).mean() for values in shap_values])
        print('class_inds', class_inds)
        colors = np.array(['yellowgreen', 'palevioletred', 'lightcoral', 'mediumpurple', 'cornflowerblue',
                           'orange'])[class_inds]
        cmap = plt_colors.ListedColormap(colors)
        shap.summary_plot(shap_values, input_feat, feature_names=feat_name, max_display=40,
                          plot_size=(12.0, 16.0, 2.0), plot_type='bar',
                          color=cmap, show=True, sort=True,
                          class_names=['Ovarian (T)', 'Ovarian (N)', 'Prostate (T)', 'Prostate (N)', 'Breast (T)',
                                       'Breast (N)'])

    # Log: the files are opened in append mode and closed again by the
    # context manager, even if writing fails part-way.
    with open(os.path.join(MODEL_DIR, 'log_blacc.txt'), 'a') as log1, \
            open(os.path.join(MODEL_DIR, 'log_roc.txt'), 'a') as log2, \
            open(os.path.join(MODEL_DIR, 'log_acc.txt'), 'a') as log3:

        def log_string(out1, out2, out3):
            """Append one record per log file; ``None`` skips logs 2 and 3."""
            log1.write(str(out1))
            log1.write('\n')
            log1.flush()
            print(out1)
            if out2 is not None:
                log2.write(str(out2['auc']))
                log2.write('\n')
                log2.write(' '.join(out2['fpr']))
                log2.write('\n')
                log2.write(' '.join(out2['tpr']))
                log2.write('\n')
                log2.flush()
                print(out2)
            # Compare against None so an accuracy of exactly 0.0 is still logged.
            if out3 is not None:
                log3.write(str(out3))
                log3.write('\n')
                log3.flush()

        if opt.use_argmax:
            # Flattened confusion matrix: 36 entries for six classes, 4 for two.
            confusion_1 = ' '.join(confusion_0.astype(str).ravel())
            if opt.use_all:
                log_string(confusion_1, None, None)
            else:
                log_string(confusion_1, roc_feat, acc)
        else:
            log_string(balanced_acc, None, acc)
|
|
220 |
|
|
|
221 |
|
|
|
222 |
if __name__ == '__main__':
    # Script entry point: opt.phase selects between training and testing.
    if opt.phase == 'train':
        import shlex  # local import: only needed to quote the backup path

        code_backup_dir = os.path.join(MODEL_DIR, 'code/')
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(code_backup_dir, exist_ok=True)
        # Back up the model definition: copy the working directory next to the
        # model. The destination is quoted so that a MODEL_DIR containing
        # spaces or shell metacharacters does not break the command; the glob
        # is left unquoted so the shell still expands it.
        os.system('cp -r * %s' % shlex.quote(code_backup_dir))  # bkp of model def
        mlc_ae(True)
    elif opt.phase == 'test':
        mlc_ae(False)