--- a
+++ b/modules/prediction_model.py
@@ -0,0 +1,330 @@
+import pandas as pd
+import numpy as np
+import argparse
+import json
+
+#### GCN #################
+from GCN_transformer import *
+
+import os.path as osp
+import os
+
+import random
+
+import networkx as nx
+from scipy import sparse
+
+import torch
+import torch_geometric
+import torch.nn.functional as F
+from torch.nn import Linear, BCEWithLogitsLoss
+from torch_geometric import transforms as T
+from torch_geometric.data import Data, Dataset, InMemoryDataset
+from torch_geometric.datasets import PPI
+from torch_geometric.loader import DataLoader
+import torch_geometric.nn as geom_nn
+from torch.utils.data import random_split
+from torch_geometric.nn import GATConv, GraphConv, GCNConv
+
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import roc_auc_score
+
+###########Dataset generation############
+import argparse
+import networkx as nx
+import numpy as np
+from scipy import sparse
+
+import torch
+import torch.nn.functional as F
+from torch_geometric import transforms as T
+from torch_geometric.data import Data, Dataset, InMemoryDataset
+from torch_geometric.loader import DataLoader
+from scipy.stats import zscore
+from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler, Normalizer, RobustScaler, LabelEncoder
+
+#======================================================================
+def seed_everything(seed = 42):
+    """Seed the Python, NumPy and torch (CPU and CUDA) RNGs and request deterministic cuDNN kernels."""
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    # CPU generator, the current CUDA device, then every CUDA device (multi-GPU).
+    for seed_backend in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
+        seed_backend(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+### Dataset generation
+def sparse_mx_to_torch_sparse_tensor(sparse_mx):
+    """Return the (2, nnz) int64 COO index tensor (row ids; column ids) of a scipy sparse matrix; values are dropped."""
+    coo = sparse_mx.tocoo()
+    # np.int64 replaces np.compat.long (deprecated, not exposed by NumPy 2) and always maps to torch.long.
+    return torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
+
+def processing_topology(graph):
+    '''
+    input: networkx graph (node labels must be sortable)
+    output: (sorted node list, 2 x E edge-index tensor whose entries are positions in that list)
+    '''
+    ordered_nodes = sorted(graph.nodes())
+    dense_adj = np.array(nx.adjacency_matrix(graph, nodelist=ordered_nodes).todense())
+    coo_adj = sparse.csr_matrix(dense_adj).tocoo()
+    return ordered_nodes, sparse_mx_to_torch_sparse_tensor(coo_adj)
+
+class AsthmaDataset(InMemoryDataset):  # PyG in-memory dataset built from an already assembled list of Data graphs
+    def __init__(self, root, data_list=None, transform=None):  # root/transform are forwarded to InMemoryDataset
+        self.data_list = data_list  # set before super().__init__ because process() reads it
+        super().__init__(root, transform)  # NOTE(review): presumably calls process() only when the processed file is missing - confirm
+        self.data, self.slices = torch.load(self.processed_paths[0])  # load the collated result that process() saved
+    @property
+    def processed_file_names(self):  # file name behind self.processed_paths[0]
+        return 'data.pt'
+    def process(self):  # collate self.data_list and save it to self.processed_paths[0]
+        torch.save(self.collate(self.data_list), self.processed_paths[0])
+
+if __name__=="__main__":
+    parser = argparse.ArgumentParser(description="Train and evaluate a graph model on per-sample multi-omics graphs.")
+    parser.add_argument("--label",'-l',dest='label',required=True,help="clinical column holding the class label (values 1 and 2 are used)")
+    parser.add_argument("-t",required=True,help="transcriptome TSV: first column is the row index, remaining columns are integer sample ids")
+    parser.add_argument("-m",required=True,help="methylome TSV, same layout as -t")
+    parser.add_argument("-p",required=True,help="proteome TSV, same layout as -t")
+    parser.add_argument("-clin",required=True,help="clinical TSV providing SUBJNO and the --label column")
+    parser.add_argument("-train_samples",required=True,help="text file with one integer training sample id per line")
+    parser.add_argument("-test_samples",required=True,help="text file with one integer test sample id per line")
+    parser.add_argument("-featureSelection",required=True,choices=['ourBiomarker','DEG','DEP'],help="how the gene set is chosen")
+    parser.add_argument("-propOut1",help="first propagation output TSV (node, score); used by ourBiomarker")
+    parser.add_argument("-propOut2",help="second propagation output TSV (node, score); used by ourBiomarker")
+    parser.add_argument("-DEG",help="TSV with columns gene, adj.p-val, stats; used by DEG")
+    parser.add_argument("-DEP",help="TSV with columns gene, adj.p-val, stats; used by DEP")
+    parser.add_argument("-K",type=int,help="number of top-scoring nodes taken from each propagation output (default: all)")
+    parser.add_argument("-exp_name",required=True,help="experiment name used for the dataset directory and output file names")
+    parser.add_argument("-nwk",required=True,help="two-column TSV edge list of the gene network")
+    parser.add_argument('-random_seed',type=int,default=1,help="base random seed (default: 1)")
+    args = parser.parse_args()
+    # Inputs used unconditionally below are required, so a missing one fails in argparse rather than deep inside pandas.
+    torch.cuda.empty_cache()
+    random_seed = args.random_seed
+    seed_everything(random_seed+41)
+
+    # Choose the candidate genes; an unrecognised mode fails here instead of as a NameError on `genes` later.
+    if args.featureSelection=='ourBiomarker':
+        # Retrieve biomarker candidates: union of the K highest-scoring nodes of each propagation output
+        our_out1 = pd.read_csv(args.propOut1,sep='\t',names=['node','prop_score'],header=0).sort_values(by='prop_score',ascending=False)
+        our_out2 = pd.read_csv(args.propOut2,sep='\t',names=['node','prop_score'],header=0).sort_values(by='prop_score',ascending=False)
+        our_out_genes1 = our_out1.sort_values(by='prop_score',ascending=False)['node'].unique()
+        our_out_genes2 = our_out2.sort_values(by='prop_score',ascending=False)['node'].unique()
+        genes = list(set.union(set(our_out_genes1[:args.K]),set(our_out_genes2[:args.K])))
+    elif args.featureSelection in ('DEG','DEP'):
+        # Both modes read a (gene, adj.p-val, stats) table and keep the genes with adj.p-val < 0.05
+        de_path = args.DEG if args.featureSelection == 'DEG' else args.DEP
+        de_raw = pd.read_csv(de_path, sep='\t',names=['gene','adj.p-val','stats']).dropna()
+        genes = de_raw.sort_values(by='adj.p-val',ascending=True).loc[lambda x:x['adj.p-val']<0.05,:]['gene'].to_list()
+    else:
+        raise ValueError(
+            "-featureSelection must be one of 'ourBiomarker', 'DEG' or 'DEP', got {!r}".format(args.featureSelection)
+        )
+    # `genes` is now a list of gene identifiers (its order is not meaningful for ourBiomarker).
+    
+    # Data preparation
+    label=args.label
+    # Omics tables: the first file column becomes the row index, the remaining columns are integer sample ids.
+    transcriptome = pd.read_csv(args.t,sep='\t',index_col=0)
+    transcriptome.columns = transcriptome.columns.astype(int)
+    methylome= pd.read_csv(args.m,sep='\t',index_col=0)
+    methylome.columns = methylome.columns.astype(int)
+    proteome = pd.read_csv(args.p,sep='\t',index_col=0)
+    proteome.columns = proteome.columns.astype(int)
+    # Clinical table: dict_clinical maps label value -> SUBJNO list; dict_clinical_r maps row index -> label.
+    clinical_raw = pd.read_csv(args.clin,sep='\t',index_col=0)
+    dict_clinical = clinical_raw.reset_index().groupby(label)['SUBJNO'].apply(list).to_dict()
+    dict_clinical_r = clinical_raw.loc[:,label].to_dict()
+    all_samples_clinical = {v for i in dict_clinical.values() for v in i}
+    def _read_sample_ids(path):
+        """Read integer sample ids, one per line; blank lines are skipped and the file is always closed."""
+        with open(path) as handle:
+            return {int(line) for line in handle if line.strip()}
+    train_samples = _read_sample_ids(args.train_samples)
+    test_samples = _read_sample_ids(args.test_samples)
+
+    def clinical_label(x,dict_):
+        """Return 'low' if x is listed under key 1, 'high' if under key 2, otherwise the string 'None'."""
+        # Keys are checked in order, so key 2 is only looked up when key 1 does not contain x.
+        for key, name in ((1, 'low'), (2, 'high')):
+            if x in dict_[key]:
+                return name
+        return 'None'
+   
+    # normalize input data
+    scaler = MinMaxScaler()
+
+    def _filter_and_scale(omics):
+        """Return (raw, scaled) samples-by-genes frames restricted to clinical samples and selected genes."""
+        by_sample = omics.T
+        keep_rows = list(set(by_sample.index) & set(all_samples_clinical))
+        keep_cols = by_sample.columns.intersection(genes)
+        raw = by_sample.loc[keep_rows, keep_cols]
+        # The scaler works column-wise, so scaling the transpose min-max normalises each sample over its genes.
+        scaled = scaler.fit_transform(raw.T).T
+        return raw, pd.DataFrame(scaled, index=raw.index, columns=raw.columns)
+
+    # Each call refits the shared scaler, so the three omics layers are normalised independently.
+    # Both the unscaled (*_filt_raw) and the scaled (*_filt) frames are kept under their usual names.
+    transcriptome_filt_raw, transcriptome_filt = _filter_and_scale(transcriptome)
+    methylome_filt_raw, methylome_filt = _filter_and_scale(methylome)
+    proteome_filt_raw, proteome_filt = _filter_and_scale(proteome)
+
+    
+    ########### Dataset generation for GCN #################
+    # Build the gene network from the two-column edge list and restrict it to the selected genes.
+    nwk = pd.read_csv(args.nwk,sep='\t',names=['g1','g2'])
+    G = nx.from_pandas_edgelist(nwk, source='g1', target='g2')
+    subgraph = G.subgraph(genes)
+    if subgraph.number_of_nodes() == 0:  # max() below would otherwise fail with an opaque "empty sequence" error
+        raise ValueError("None of the selected genes occur in the network file {}".format(args.nwk))
+    # Keep only the largest connected component of the induced subgraph.
+    subgraph = subgraph.subgraph(max(nx.connected_components(subgraph), key=len))
+    nodes, edge_index = processing_topology(subgraph)
+  
+    def _median_of_modes(df, sample_ids):
+        """For the rows in sample_ids, take the first row of DataFrame.mode() and return its median over columns."""
+        return df.loc[list(sample_ids),:].mode().iloc[0,:].median()
+    def imputed_per_group(df):
+        """Return the fill value per class (label 1, label 2), using only training samples present in df."""
+        in_train = set(df.index) & set(train_samples)
+        return tuple(_median_of_modes(df, in_train & set(dict_clinical[k])) for k in (1, 2))
+    def imputed(df):
+        """Return the class-agnostic fill value, using only training samples present in df."""
+        return _median_of_modes(df, set(df.index) & set(train_samples))
+    # Per-class fill values (computed here; not referenced further down in this script).
+    tr_1, tr_2 = imputed_per_group(transcriptome_filt)
+    m_1, m_2 = imputed_per_group(methylome_filt)
+    p_1, p_2 = imputed_per_group(proteome_filt)
+    # Fill values substituted when a gene or sample is missing from an omics layer.
+    tr_ = imputed(transcriptome_filt)
+    m_ = imputed(methylome_filt)
+    p_ = imputed(proteome_filt)
+
+    data_list_train = []
+    data_list_test = []
+
+    # Node feature columns, in order: transcriptome, methylome, proteome; each paired with its fill value.
+    omics_layers = ((transcriptome_filt, tr_), (methylome_filt, m_), (proteome_filt, p_))
+    # Clinical label -> class index stored as the graph-level target y.
+    label_to_class = {1: 0, 2: 1}
+
+    def _node_features(sample):
+        """Return one [transcriptome, methylome, proteome] row per gene in `nodes` for this sample.
+
+        A layer contributes its own value when it contains both the gene and the sample;
+        otherwise that layer's fill value is used.
+        """
+        rows = []
+        for gene in nodes:
+            row = []
+            for layer, fill in omics_layers:
+                if (gene in layer.columns) and (sample in layer.index):
+                    row.append(layer.loc[sample,gene])
+                else:
+                    row.append(fill)
+            rows.append(row)
+        return rows
+
+    def _build_sample_graph(sample):
+        """Build the PyG Data graph for one sample, or return None when its label is neither 1 nor 2.
+
+        Previously such a sample re-appended the preceding sample's graph (or raised NameError if it
+        came first), because `data` was only assigned inside the label checks; it is now skipped.
+        """
+        target = label_to_class.get(dict_clinical_r[sample])
+        if target is None:
+            return None
+        # Rows follow the order of `nodes`, which is the node order edge_index was built from.
+        x = torch.tensor(np.array(_node_features(sample),dtype=np.float32)).view(-1,3)
+        # Every sample shares the same topology; only node features and label differ.
+        return Data(x=x, y=torch.tensor([target]), edge_index=edge_index)
+
+    # Train and test membership are checked independently, so a sample listed in both id files
+    # gets a separate graph in each list (same as the former duplicated train/test branches).
+    splits = ((train_samples, data_list_train), (test_samples, data_list_test))
+    for sample in all_samples_clinical:
+        for members, bucket in splits:
+            if sample not in members:
+                continue
+            graph = _build_sample_graph(sample)
+            if graph is None:
+                continue  # label outside {1, 2}: no class to assign, leave the sample out
+            bucket.append(graph)
+    # data_list_train / data_list_test now hold one Data object per sample labelled 1 or 2.
+
+    # Wrap the graph lists as PyG datasets rooted at directories named after the experiment.
+    # NOTE(review): an existing processed cache under the same root is presumably loaded instead of
+    # these freshly built lists - confirm, and clear the directory when inputs change.
+    Dataset_name = "Dataset_minmaxSample_{}".format(args.exp_name)
+    Asthma_train = AsthmaDataset(Dataset_name, data_list_train)
+    Asthma_test = AsthmaDataset(Dataset_name+'.test', data_list_test)
+
+    ### Train model
+    VALID_RATIO = 0.8  # passed as train_size below: the fraction kept for training, not for validation
+    g = torch.Generator()
+    g.manual_seed(torch.initial_seed())
+    n_train_examples = int(len(Asthma_train) * VALID_RATIO)
+    n_valid_examples = len(Asthma_train) - n_train_examples
+
+    def stratified_split(dataset):
+        """Split dataset into (train, validation) Subsets, stratified on each graph's label."""
+        from sklearn.model_selection import train_test_split
+        targets = [graph.y.item() for graph in dataset]
+        positions = list(range(len(targets)))
+        kept, held_out = train_test_split(positions, train_size=VALID_RATIO, shuffle=True,
+                                          stratify=targets, random_state=(random_seed+42))
+        return torch.utils.data.Subset(dataset, kept), torch.utils.data.Subset(dataset, held_out)
+
+    graph_train_data, graph_valid_data = stratified_split(Asthma_train)
+
+    def seed_worker(worker_id):
+        """DataLoader worker_init_fn: reseed NumPy and `random` from the worker's torch seed."""
+        worker_seed = torch.initial_seed() % 2**32
+        np.random.seed(worker_seed)  # the file imports numpy as `np`; the bare name `numpy` is not imported here
+        random.seed(worker_seed)
+    BATCH_SIZE = 10
+    # Settings common to all three loaders (num_workers=0, shared seeded generator).
+    loader_common = dict(worker_init_fn=seed_worker, generator=g, num_workers=0)
+    GraphLoader = torch_geometric.loader.DataLoader
+    graph_train_loader = GraphLoader(graph_train_data, shuffle=True, batch_size=BATCH_SIZE, **loader_common)
+    graph_val_loader = GraphLoader(graph_valid_data, shuffle=True, batch_size=BATCH_SIZE, **loader_common)
+    graph_test_loader = GraphLoader(Asthma_test, batch_size=1, **loader_common)  # one graph per step, no shuffling
+
+    # Hyper-parameters are attached to the argparse namespace that experiment_graph receives.
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    args.device = device
+    print(args.device)
+    for option, value in (('epochs', 40), ('test', True), ('learning_rate', 0.001),
+                          ('batch_size', BATCH_SIZE), ('weight_decay', 0), ('dropout_rate', 0.2)):
+        setattr(args, option, value)
+    model, best_performances, test_loss, test_acc = experiment_graph(args, graph_train_loader, graph_val_loader, graph_test_loader)
+    
+    # Collect model outputs and true labels over the test graphs.
+    # eval() + no_grad(): dropout_rate is 0.2, so scoring a model left in training mode would be stochastic.
+    model.eval()
+    y_li , true_y_li = [],[]
+    with torch.no_grad():
+        for data in graph_test_loader:
+            data = data.to(device)
+            out, att_idx, att_w = model(data)  # only `out` is used for scoring
+            y_li.extend(out.cpu().flatten().tolist())
+            true_y_li.extend(data.y.cpu().flatten().tolist())
+    
+    # Threshold-free evaluation of the raw outputs against the true labels (positive class = 1).
+    from sklearn import metrics
+    from collections import Counter
+    fpr,tpr,thres_roc = metrics.roc_curve(true_y_li,y_li,pos_label=1)
+    precision,recall,thres_pr = metrics.precision_recall_curve(true_y_li,y_li,pos_label=1)
+    auroc = metrics.auc(fpr,tpr)
+    auprc = metrics.auc(recall,precision)
+    positive_rate = Counter(true_y_li)[1]/len(true_y_li)  # share of positives, reported as the AUPRC baseline
+    report = ["AUROC: {:.5f}".format(auroc), "AUPRC: {:.5f} \t baseline: {:.5f}".format(auprc, positive_rate)]
+    with open(args.exp_name+".performance.txt",'w') as f:
+        f.write("\n".join(report) + "\n")
+    print("==================")
+    torch.save(model.state_dict(), args.exp_name + '.TransformerConv.best_model')