# src/run_scMDC.py
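"""Command-line entry point: load a paired multi-omics dataset (e.g. CITE-seq
mRNA + ADT counts), preprocess both modalities, pretrain the scMultiCluster
autoencoder, run the deep-clustering stage, and optionally save predictions,
embeddings, and clustering metrics."""
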
from time import time
import math, os
from sklearn import metrics
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from scMDC import scMultiCluster
import numpy as np
import collections
import h5py
import scanpy as sc
from preprocess import read_dataset, normalize, clr_normalize_each_cell
from utils import *

if __name__ == "__main__":

    # Set the hyperparameters.
    import argparse
    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--n_clusters', default=27, type=int, help='Number of clusters; -1 estimates k from the latent space')
    parser.add_argument('--cutoff', default=0.5, type=float, help='Fraction of epochs after which the combined layer starts training')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5')
    parser.add_argument('--maxiter', default=5000, type=int)
    parser.add_argument('--pretrain_epochs', default=400, type=int)
    parser.add_argument('--gamma', default=.1, type=float,
                        help='coefficient of clustering loss')
    parser.add_argument('--tau', default=1., type=float,
                        help='fuzziness of clustering loss')
    parser.add_argument('--phi1', default=0.001, type=float,
                        help='coefficient of KL loss in pretraining stage')
    parser.add_argument('--phi2', default=0.001, type=float,
                        help='coefficient of KL loss in clustering stage')
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--lr', default=1., type=float)
    parser.add_argument('--ae_weights', default=None)
    parser.add_argument('--save_dir', default='results/')
    parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
    parser.add_argument('--resolution', default=0.2, type=float)
    parser.add_argument('--n_neighbors', default=30, type=int)
    parser.add_argument('--embedding_file', action='store_true', default=False)
    parser.add_argument('--prediction_file', action='store_true', default=False)
    parser.add_argument('-el', '--encodeLayer', nargs='+', default=[256, 64, 32, 16])
    parser.add_argument('-dl1', '--decodeLayer1', nargs='+', default=[16, 64, 256])
    parser.add_argument('-dl2', '--decodeLayer2', nargs='+', default=[16, 20])
    parser.add_argument('--sigma1', default=2.5, type=float)
    parser.add_argument('--sigma2', default=1.5, type=float)
    parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA features to keep after selection')
    parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC features to keep after selection')
    parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA feature selection')
    parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC feature selection')
    parser.add_argument('--run', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--no_labels', action='store_true', default=False)
    args = parser.parse_args()
    print(args)
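
    # Load the paired data. The HDF5 file holds two count matrices, X1 (mRNA)
    # and X2 (the second modality, ADT or ATAC counts), plus a label vector Y
    # unless --no_labels is set.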
    data_mat = h5py.File(args.data_file, 'r')
    x1 = np.array(data_mat['X1'])
    x2 = np.array(data_mat['X2'])
    if not args.no_labels:
        y = np.array(data_mat['Y'])
    data_mat.close()
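
    # Optional per-modality feature selection. geneSelection comes from utils;
    # it is assumed to rank features by an informativeness score and return the
    # indices of the top n.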
    # Gene filtering
    if args.filter1:
        importantGenes = geneSelection(x1, n=args.f1, plot=False)
        x1 = x1[:, importantGenes]
    if args.filter2:
        importantGenes = geneSelection(x2, n=args.f2, plot=False)
        x2 = x2[:, importantGenes]

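    # Both modalities get the same treatment: a copy of the raw counts is kept
    # in .raw (used later as the reconstruction target), and normalize computes
    # per-cell size factors, log-transforms the data, and scales the input,
    # per its keyword flags.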
    # Preprocess the scRNA-seq read count matrix.
    adata1 = sc.AnnData(x1)
    #adata1.obs['Group'] = y

    adata1 = read_dataset(adata1,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata1 = normalize(adata1,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    # Preprocess the second modality the same way.
    adata2 = sc.AnnData(x2)
    #adata2.obs['Group'] = y
    adata2 = read_dataset(adata2,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata2 = normalize(adata2,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    # Alternative for ADT counts: CLR normalization per cell.
    #adata2 = clr_normalize_each_cell(adata2)

    input_size1 = adata1.n_vars
    input_size2 = adata2.n_vars

    # nargs='+' values arrive as strings when given on the command line, so
    # cast the layer sizes to int.
    encodeLayer = list(map(int, args.encodeLayer))
    decodeLayer1 = list(map(int, args.decodeLayer1))
    decodeLayer2 = list(map(int, args.decodeLayer2))

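    # Build the multimodal clustering model: what appears to be a shared
    # encoder with one decoder per modality (decodeLayer1/decodeLayer2).
    # gamma and tau weight and soften the clustering loss (per the argument
    # help above); sigma1/sigma2 are assumed to be input noise levels used
    # for denoising.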
    model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
                           encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
                           activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
                           cutoff=args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)

    print(str(model))

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

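    # Stage 1: pretrain the autoencoder on both modalities, or load existing
    # weights. Each modality passes in its normalized matrix, raw counts, and
    # size factors, which suggests a count-based (ZINB-style) reconstruction
    # loss; phi1 weights the KL term during pretraining.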
    t0 = time()
    if args.ae_weights is None:
        model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, batch_size=args.batch_size,
                epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file)
    else:
        if os.path.isfile(args.ae_weights):
            print("==> loading checkpoint '{}'".format(args.ae_weights))
            checkpoint = torch.load(args.ae_weights)
            model.load_state_dict(checkpoint['ae_state_dict'])
        else:
            print("==> no checkpoint found at '{}'".format(args.ae_weights))
            raise ValueError("Pretrained autoencoder weights not found: '{}'".format(args.ae_weights))

    print('Pretraining time: %d seconds.' % int(time() - t0))

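    # Determine k: either take --n_clusters directly, or estimate it from the
    # pretrained latent space. GetCluster (from utils) is assumed to run a
    # graph-based clustering (e.g. Louvain) on a KNN graph built with n
    # neighbors at the given resolution and return the number of clusters found.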
    latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
    latent = latent.cpu().numpy()
    if args.n_clusters == -1:
        n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
    else:
        print("n_clusters is defined as " + str(args.n_clusters))
        n_clusters = args.n_clusters

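    # Stage 2: joint deep clustering. model.fit is assumed to follow the usual
    # DEC recipe: initialize cluster centers on the latent space, then
    # alternate between refreshing soft assignments every update_interval
    # epochs and training on the combined reconstruction + clustering loss,
    # stopping early once fewer than tol of the assignments change. Passing y
    # only enables progress evaluation during training.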
    if not args.no_labels:
        y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=y,
                n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
                update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
    else:
        y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=None,
                n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
                update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
    print('Total time: %d seconds.' % int(time() - t0))

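    # Optionally save the predicted labels. best_map (from utils) is assumed
    # to permute the predicted cluster IDs to best match the true labels
    # (Hungarian matching); the -1 appears to shift the matched labels to
    # zero-based IDs.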
    if args.prediction_file:
        if not args.no_labels:
            y_pred_ = best_map(y, y_pred) - 1
            np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred_, delimiter=",")
        else:
            np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred, delimiter=",")

    if args.embedding_file:
        final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
        final_latent = final_latent.cpu().numpy()
        np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",")

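    # Evaluate against the true labels. AMI, NMI, and ARI are invariant to
    # label permutations, so they can be computed on y_pred directly without
    # remapping it via best_map.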
    if not args.no_labels:
        ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
        nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
        ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
        print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
    else:
        print("No labels for evaluation!")