# src/run_LRP.py
1
from time import time
2
import math, os
3
from sklearn import metrics
4
from sklearn.cluster import KMeans
5
import torch
6
import torch.nn as nn
7
from torch.autograd import Variable
8
from torch.nn import Parameter
9
import torch.nn.functional as F
10
import torch.optim as optim
11
from torch.utils.data import DataLoader, TensorDataset
12
13
from scMDC import scMultiCluster
14
import numpy as np
15
import collections
16
import h5py
17
import scanpy as sc
18
from preprocess import read_dataset, normalize, clr_normalize_each_cell
19
from utils import *
20
from functools import reduce
21
from LRP import LRP
22
23
24
if __name__ == "__main__":
    # Script entry point. Steps, in order:
    #   1. parse command-line options;
    #   2. load the two feature matrices ('X1', 'X2') from the HDF5 data file
    #      and the per-cell cluster ids from a comma-delimited text file;
    #   3. optionally select features, then preprocess both matrices;
    #   4. build scMultiCluster and load its weights from --ae_weights;
    #   5. encode all cells and build the LRP explainer;
    #   6. for every cluster, compute one-vs-rest relevance scores and write
    #      them as CSV files into --save_dir.

    # setting the hyper parameters
    import argparse
    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--n_clusters', default=8, type=int)
    parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--data_file', default='Simulation.1.h5')
    parser.add_argument('--cluster_index_file', default='label.txt')
    parser.add_argument('--maxiter', default=10000, type=int)
    parser.add_argument('--pretrain_epochs', default=400, type=int)
    parser.add_argument('--gamma', default=.1, type=float,
                        help='coefficient of clustering loss')
    parser.add_argument('--tau', default=1., type=float,
                        help='fuzziness of clustering loss')
    parser.add_argument('--phi1', default=0.001, type=float,
                        help='coefficient of KL loss in pretraining stage')
    parser.add_argument('--phi2', default=0.001, type=float,
                        help='coefficient of KL loss in clustering stage')
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--ae_weights', default=None)
    parser.add_argument('--save_dir', default='results/')
    parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
    parser.add_argument('--resolution', default=0.2, type=float)
    parser.add_argument('--n_neighbors', default=30, type=int)
    parser.add_argument('--embedding_file', action='store_true', default=False)
    parser.add_argument('--prediction_file', action='store_true', default=False)
    parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16])
    parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256])
    parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20])
    parser.add_argument('--sigma1', default=2.5, type=float)
    parser.add_argument('--sigma2', default=1.5, type=float)
    parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection')
    parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection')
    parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection')
    parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection')
    parser.add_argument('--run', default=1, type=int)
    parser.add_argument('--beta', default=1., type=float,
                        help='coefficient of the clustering fuzziness')
    parser.add_argument('--margin', default=1., type=float,
                        help='margin of difference between logits')
    parser.add_argument('--lamda', default=100., type=float,
                        help='coefficient of the clustering perturbation loss')
    # The learning rate is fractional: with type=int any value given on the
    # command line (e.g. "--lr 0.01") was rejected by argparse.
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--device', default='cuda')

    args = parser.parse_args()
    print(args)

    # Read both modalities. Open read-only and let the context manager close
    # the file even if one of the datasets is missing.
    with h5py.File(args.data_file, 'r') as data_mat:
        x1 = np.array(data_mat['X1'])
        x2 = np.array(data_mat['X2'])
        #y = np.array(data_mat['Y']) - 1

    # One integer cluster id per row of the cluster index file.
    clust_ids = np.loadtxt(args.cluster_index_file, delimiter=",").astype(int)

    #Gene features
    if args.filter1:
        importantGenes = geneSelection(x1, n=args.f1, plot=False)
        x1 = x1[:, importantGenes]
    if args.filter2:
        importantGenes = geneSelection(x2, n=args.f2, plot=False)
        x2 = x2[:, importantGenes]

    # Both modalities go through the same read_dataset/normalize calls.
    adata1 = sc.AnnData(x1)
    #adata1.obs['Group'] = y

    adata1 = read_dataset(adata1,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata1 = normalize(adata1,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    adata2 = sc.AnnData(x2)
    #adata2.obs['Group'] = y
    adata2 = read_dataset(adata2,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata2 = normalize(adata2,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    #adata2 = clr_normalize_each_cell(adata2)

    input_size1 = adata1.n_vars
    input_size2 = adata2.n_vars
    print(adata1.X.shape)
    print(adata2.X.shape)

    # nargs='+' yields strings when given on the command line; cast to int.
    encodeLayer = list(map(int, args.encodeLayer))
    decodeLayer1 = list(map(int, args.decodeLayer1))
    decodeLayer2 = list(map(int, args.decodeLayer2))

    model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
                           encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
                           activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
                           cutoff=args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)

    print(str(model))

    # A checkpoint is mandatory. --ae_weights defaults to None, which
    # os.path.isfile() rejects with a TypeError, so test for it explicitly and
    # fail with a descriptive ValueError instead.
    if args.ae_weights is None or not os.path.isfile(args.ae_weights):
        print("==> no checkpoint found at '{}'".format(args.ae_weights))
        raise ValueError(
            "no checkpoint found at '{}'; pass an existing file via --ae_weights".format(args.ae_weights))

    print("==> loading checkpoint '{}'".format(args.ae_weights))
    # map_location lets a checkpoint saved on one device be loaded on another
    # (e.g. saved on GPU, analysed with --device cpu).
    checkpoint = torch.load(args.ae_weights, map_location=args.device)
    model.load_state_dict(checkpoint['ae_state_dict'])

    # The number of clusters is taken from the label file, not from --n_clusters.
    unique_clusters = np.unique(clust_ids)
    n_clusters = unique_clusters.shape[0]
    print("n cluster is: " + str(n_clusters))

    Z = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device)).data.cpu().numpy()

    cluster_list = unique_clusters.astype(int).tolist()
    print(cluster_list)

    model_explainer = LRP(model, X1=adata1.X, X2=adata2.X, Z=Z, clust_ids=clust_ids, n_clusters=n_clusters, beta=args.beta).to(args.device)

    # NOTE: a pairwise (one-vs-one) variant was previously sketched here using
    # model_explainer.calc_carlini_wagner_one_vs_one(clust_c, clust_k, margin=..., lamda=..., max_iter=..., lr=...)
    # and writing "<c>_vs_<k>_rel_mRNA_scores.csv" / "<c>_vs_<k>_rel_ADT_scores.csv".

    # Make sure the output directory exists before starting the (potentially
    # long) per-cluster computation, so results are not lost at save time.
    os.makedirs(args.save_dir, exist_ok=True)

    for clust_c in cluster_list:
        print("Cluster"+str(clust_c)+" vs Rest")
        rel_score1, rel_score2 = model_explainer.calc_carlini_wagner_one_vs_rest(clust_c, margin=args.margin, lamda=args.lamda, max_iter=args.maxiter, lr=args.lr)
        # rel_score1 is written as the mRNA scores, rel_score2 as the ADT scores.
        np.savetxt(os.path.join(args.save_dir, str(clust_c)+"_vs_rest_rel_mRNA_scores.csv"), rel_score1, delimiter=",")
        np.savetxt(os.path.join(args.save_dir, str(clust_c)+"_vs_rest_rel_ADT_scores.csv"), rel_score2, delimiter=",")