--- a/code/mlc-ae.py
+++ b/code/mlc-ae.py
@@ -0,0 +1,229 @@
+import tensorflow as tf
+from tensorflow.keras.layers import Input, Dense, BatchNormalization, GaussianNoise, GaussianDropout, Conv1D, multiply
+from tensorflow.keras.models import Model
+import tensorflow.keras.backend as backend
+import numpy as np
+from sklearn.preprocessing import LabelEncoder, normalize, RobustScaler
+import sklearn.metrics as sk
+from sklearn.metrics import mean_absolute_error as mae
+from sklearn.metrics import mean_squared_error as mse
+import tensorflow.keras.utils as np_utils
+from tensorflow.keras.callbacks import CSVLogger, History
+import data_provider
+import os
+from options import opt, MODEL_DIR
+import time
+import shap
+import pandas as pd
+
+
+def mlc_ae(training=False):
+    # load the expression matrix, labels, gene names and sample ids
+    load_file = '../data_process/tcga_pfi.h5' if opt.pfi else '../data_process/tcga.h5'
+    m_rna, label, gene, sample_id = data_provider.load_h5_all(load_file, True)
+    if not opt.pfi:
+        m_rna_o, label_o, gene_o, sample_id_o = data_provider.load_h5_all('../data_process/other.h5', True)
+        m_rna, label, sample_id = np.concatenate((m_rna, m_rna_o)), np.concatenate((label, label_o)), \
+            np.concatenate((sample_id, sample_id_o))
+    m_rna = normalize(X=m_rna, axis=0, norm="max")
+    print('feat and label size', m_rna.shape, label)
+    """first random: train and test sets"""
+    indices = np.arange(m_rna.shape[0])
+    np.random.seed(1)
+    np.random.shuffle(indices)
+    m_rna2 = m_rna[indices]
+    label2 = label[indices]
+    sample_id = sample_id[indices]
+    if opt.use_all:
+        categorical_label = np_utils.to_categorical(label2, num_classes=6)
+    else:
+        categorical_label = np_utils.to_categorical(label2, num_classes=2)
+    """to save the data split strategy for other analysis"""
+    # tofile = np.stack((sample_id.astype(str), label2.astype(str)), axis=1)
+    # np.savetxt(X=tofile, fname=MODEL_DIR + "/sample_id.txt", delimiter=",", fmt='%s')
+
+    m_rna_train = m_rna2[:-opt.test_size, ]
+    m_rna_test = m_rna2[-opt.test_size:, ]
+    categorical_label_train = categorical_label[:-opt.test_size, ]
+    categorical_label_test = categorical_label[-opt.test_size:, ]
+    label_train = label2[:-opt.test_size, ]
+    label_test = label2[-opt.test_size:, ]
+
+    # prostate (PR) sample operations: labels 2 (tumour) and 3 (normal) are remapped to 0/1
+    pr_idx_train = np.array([i for i, e in enumerate(label_train) if e == 2 or e == 3])
+    pr_idx_test = np.array([i for i, e in enumerate(label_test) if e == 2 or e == 3])
+    print('healthy samples in training set:', sum([1 for x in label_train if x == 3]))
+    print('healthy samples in testing set:', sum([1 for x in label_test if x == 3]))
+
+    pr_m_rna_train = m_rna_train[pr_idx_train]
+    pr_m_rna_test = m_rna_test[pr_idx_test]
+    pr_label_train = label_train[pr_idx_train] - 2
+    pr_label_test = label_test[pr_idx_test] - 2
+
+    print('PR train and test size:', pr_label_train.shape, pr_idx_test.shape)
+    pr_categorical_label_train = np_utils.to_categorical(pr_label_train)
+    pr_categorical_label_test = np_utils.to_categorical(pr_label_test)
+    print("data loading has just been finished")
+
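+    # Model: a multi-task autoencoder. A shared encoder compresses each expression
+    # profile to a 12-unit latent code ("encoded"); one head ("m_rna") reconstructs the
+    # input with an MSE loss (weight 0.001) and one head ("category") classifies the
+    # sample with a cosine-similarity loss (weight 0.5). The layer names are reused
+    # below to extract latent features and to point SHAP at the classification head.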
+    def create_model():
+        inputs = Input(shape=(m_rna.shape[1],), name="inputs")
+        inputs_0 = BatchNormalization(name="inputs_0")(inputs)
+        inputs_1 = Dense(1024, activation="relu", name="inputs_1")(inputs_0)
+        inputs_2 = BatchNormalization(name="inputs_2")(inputs_1)
+        inputs_3 = Dense(256, activation="relu", name="inputs_3")(inputs_2)
+        inputs_4 = BatchNormalization(name="inputs_4")(inputs_3)
+        encoded = Dense(units=12, activation='relu', name='encoded')(inputs_4)
+        inputs_5 = Dense(512, activation="relu", name="inputs_5")(encoded)
+        decoded_tcga = Dense(units=m_rna.shape[1], activation='linear', name="m_rna")(inputs_5)
+        if opt.use_all:
+            cl_0 = Dense(units=categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
+        else:
+            cl_0 = Dense(units=pr_categorical_label_train.shape[1], activation="softmax", name="category")(encoded)
+        m = Model(inputs=inputs, outputs=[decoded_tcga, cl_0])
+        m.compile(optimizer='adam',
+                  loss=["mse", "cosine_similarity"],
+                  loss_weights=[0.001, 0.5],
+                  metrics={"m_rna": ["mae", "mse"], "category": "acc"})
+        return m
+
+    model = create_model()
+    checkpoint_path = os.path.join(MODEL_DIR, 'my_model.h5')
+    # model.summary()
+    if training:
+        # file_writer = tf.summary.create_file_writer(MODEL_DIR + "/metrics")
+        # file_writer.set_as_default()
+        def lr_scheduler(epoch):
+            lr = 0.01
+            if epoch < 200:
+                lr *= 0.99999999
+            else:
+                lr *= 0.99999
+            tf.summary.scalar('Learning Rate', data=lr, step=epoch)
+            return lr
+
+        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR)
+        # lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
+        if opt.use_all:
+            model.fit(m_rna_train, [m_rna_train, categorical_label_train], batch_size=opt.batch_size,
+                      epochs=opt.max_epoch,
+                      callbacks=[tensorboard_callback],
+                      validation_data=(m_rna_test, [m_rna_test, categorical_label_test]),
+                      verbose=2)
+        else:
+            model.fit(pr_m_rna_train, [pr_m_rna_train, pr_categorical_label_train], batch_size=opt.batch_size,
+                      epochs=opt.max_epoch,
+                      callbacks=[tensorboard_callback],
+                      validation_data=(pr_m_rna_test, [pr_m_rna_test, pr_categorical_label_test]),
+                      verbose=2)
+        model.save_weights(filepath=checkpoint_path)
+        print("fitting has just been finished")
+
+    else:
+        model.load_weights(checkpoint_path)
+        if not opt.use_all:
+            # evaluate on the prostate-only subset; keep the features, one-hot labels and
+            # integer labels aligned (the integer labels were not remapped originally,
+            # which made the confusion matrix below fail for the two-class case)
+            m_rna_test = pr_m_rna_test
+            categorical_label_test = pr_categorical_label_test
+            label_test = pr_label_test
+        data_pred = model.predict(m_rna_test, batch_size=opt.batch_size, verbose=2)
+
+        """ argmax """
+        y_pred = np.argmax(data_pred[1], axis=1)
+        y_gt = label_test  # np.argmax(categorical_label_test, axis=1)
+        if opt.use_all:
+            confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1, 2, 3, 4, 5])
+        else:
+            confusion_0 = sk.confusion_matrix(y_gt, y_pred, labels=[0, 1])
+        print(confusion_0)
+        balanced_acc = sk.balanced_accuracy_score(y_gt, y_pred)
+        acc = sk.accuracy_score(y_gt, y_pred)
+        """ logits for roc """
+        y_logit = data_pred[1][:, 0]
+        x_logit = categorical_label_test[:, 0]
+        """save for records"""
+        np.savetxt(X=m_rna_test, fname=MODEL_DIR + "/test_gene.csv", delimiter=",", fmt='%1.3f')
+        np.savetxt(X=label, fname=MODEL_DIR + "/label.csv", delimiter=",", fmt='%1.3f')
+        np.savetxt(X=data_pred[0], fname=MODEL_DIR + "/pred_gene.csv", delimiter=",", fmt='%1.3f')
+        np.savetxt(X=y_pred, fname=MODEL_DIR + "/pred_label.csv", delimiter=",", fmt='%1.3f')
+        """ get latent representation """
+        layer_name = "encoded"
+        encoded_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
+        encoded_output = encoded_layer_model.predict(m_rna_test)
+        np.savetxt(X=encoded_output, fname=MODEL_DIR + "/latent_feat.csv", delimiter=",")
+
+        """ log """
+        log1 = open(os.path.join(MODEL_DIR, 'log_blacc.txt'), 'a')
+        log2 = open(os.path.join(MODEL_DIR, 'log_roc.txt'), 'a')
+        log3 = open(os.path.join(MODEL_DIR, 'log_acc.txt'), 'a')
+
+        if not opt.use_all:
+            """Only executed when using PR data (two labels)"""
+            auc = sk.roc_auc_score(x_logit, y_logit)
+            fpr, tpr, thresh = sk.roc_curve(x_logit, y_logit)
+            roc_feat = {'auc': auc, 'fpr': [], 'tpr': []}
+            for e in fpr:
+                roc_feat['fpr'].append(str(e))
+            for e in tpr:
+                roc_feat['tpr'].append(str(e))
+
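+        # Feature attribution: SHAP's GradientExplainer is pointed at the classification
+        # head only (the "category" layer), using the training matrix both as background
+        # and as the samples to explain; the per-class SHAP values are then summarised
+        # in a bar plot with one colour per class.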
+        if opt.use_shap:
+            """ Gradient/Deep Explainer """
+            print("Processing SHAP...")
+            model.load_weights(checkpoint_path)
+            from matplotlib import colors as plt_colors
+            layer_name = "category"
+            encoded_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
+            input_feat = m_rna_train if opt.sample_all == 'all' else pr_m_rna_train
+            e = shap.GradientExplainer(encoded_layer_model, input_feat)
+            # e = shap.DeepExplainer(encoded_layer_model, input_feat)
+            shap_values = e.shap_values(input_feat)
+            shap_load_gene = '../data_process/gene_pfi.csv' if opt.pfi else '../data_process/gene.csv'
+            feat_name = np.loadtxt(shap_load_gene, dtype=str, delimiter=",")  # [:, 3]
+            # order classes by mean |SHAP| so the most influential class is drawn first
+            class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])
+            print('class_inds', class_inds)
+            colors = np.array(['yellowgreen', 'palevioletred', 'lightcoral', 'mediumpurple', 'cornflowerblue',
+                               'orange'])[class_inds]
+            cmap = plt_colors.ListedColormap(colors)
+            shap.summary_plot(shap_values, input_feat, feature_names=feat_name, max_display=40,
+                              plot_size=(12.0, 16.0), plot_type='bar',
+                              color=cmap, show=True, sort=True,
+                              class_names=['Ovarian (T)', 'Ovarian (N)', 'Prostate (T)', 'Prostate (N)', 'Breast (T)',
+                                           'Breast (N)'])
+
+        def log_string(out1, out2, out3):
+            log1.write(str(out1))
+            log1.write('\n')
+            log1.flush()
+            print(out1)
+            if out2:
+                log2.write(str(out2['auc']))
+                log2.write('\n')
+                roc_x, roc_y = ' '.join(out2['fpr']), ' '.join(out2['tpr'])
+                log2.write(roc_x)
+                log2.write('\n')
+                log2.write(roc_y)
+                log2.write('\n')
+                log2.flush()
+                print(out2)
+            if out3:
+                log3.write(str(out3))
+                log3.write('\n')
+                log3.flush()
+
+        if opt.use_argmax:
+            if opt.use_all:
+                confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 36)))
+                log_string(confusion_1, None, None)
+            else:
+                confusion_1 = ' '.join(list(np.reshape(confusion_0.astype(str), 4)))
+                log_string(confusion_1, roc_feat, acc)
+        else:
+            log_string(balanced_acc, None, acc)
+
+
+if __name__ == '__main__':
+    if opt.phase == 'train':
+        if not os.path.exists(os.path.join(MODEL_DIR, 'code/')):
+            os.makedirs(os.path.join(MODEL_DIR, 'code/'))
+        os.system('cp -r * %s' % (os.path.join(MODEL_DIR, 'code/')))  # back up the model definition
+        mlc_ae(True)
+    elif opt.phase == 'test':
+        mlc_ae(False)
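+# Usage sketch (the flag name is an assumption; the actual CLI is defined in options.py):
+#   python mlc-ae.py --phase train   # fit the model and save my_model.h5 under MODEL_DIR
+#   python mlc-ae.py --phase test    # reload the weights, evaluate, and write the CSV/log outputs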