src/run_scMDC_batch.py
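"""Run scMDC with batch correction on paired multi-omics single-cell data
(mRNA plus ADT/ATAC): pretrain the autoencoder, cluster the latent embedding,
and optionally write predictions, embeddings, and label-based metrics."""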
from time import time
import os

import h5py
import numpy as np
import scanpy as sc
import torch
from sklearn import metrics
from sklearn.preprocessing import OneHotEncoder

from scMDC_batch import scMultiClusterBatch
from preprocess import read_dataset, normalize
from utils import *

if __name__ == "__main__":

    # set the hyperparameters
    import argparse
    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--n_clusters', default=27, type=int,
                        help='Number of clusters; use -1 to estimate it from the latent embedding')
    parser.add_argument('--cutoff', default=0.5, type=float,
                        help='Fraction of pretraining epochs after which the combined layer starts training')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5')
    parser.add_argument('--maxiter', default=5000, type=int)
    parser.add_argument('--pretrain_epochs', default=400, type=int)
    parser.add_argument('--gamma', default=.1, type=float,
                        help='coefficient of the clustering loss')
    parser.add_argument('--tau', default=1., type=float,
                        help='weight of the clustering loss')
    parser.add_argument('--phi1', default=0.001, type=float,
                        help='coefficient of the KL loss in the pretraining stage')
    parser.add_argument('--phi2', default=0.001, type=float,
                        help='coefficient of the KL loss in the clustering stage')
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--lr', default=1., type=float)
    parser.add_argument('--ae_weights', default=None)
    parser.add_argument('--save_dir', default='results/')
    parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
    parser.add_argument('--resolution', default=0.2, type=float)
    parser.add_argument('--n_neighbors', default=30, type=int)
    parser.add_argument('--embedding_file', action='store_true', default=False)
    parser.add_argument('--prediction_file', action='store_true', default=False)
    parser.add_argument('-el', '--encodeLayer', nargs='+', default=[256, 64, 32, 16])
    parser.add_argument('-dl1', '--decodeLayer1', nargs='+', default=[16, 64, 256])
    parser.add_argument('-dl2', '--decodeLayer2', nargs='+', default=[16, 20])
    parser.add_argument('--sigma1', default=2.5, type=float)
    parser.add_argument('--sigma2', default=1.5, type=float)
    parser.add_argument('--f1', default=1000, type=int, help='Number of mRNA features to keep after selection')
    parser.add_argument('--f2', default=2000, type=int, help='Number of ADT/ATAC features to keep after selection')
    parser.add_argument('--filter1', action='store_true', default=False, help='Apply mRNA feature selection')
    parser.add_argument('--filter2', action='store_true', default=False, help='Apply ADT/ATAC feature selection')
    parser.add_argument('--nbatch', default=2, type=int)
    parser.add_argument('--run', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--no_labels', action='store_true', default=False)
    args = parser.parse_args()
    print(args)
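
    # Example invocation (illustrative values; adjust paths and sizes to your data):
    #   python run_scMDC_batch.py --data_file Normalized_filtered_BMNC_GSE128639_Seurat.h5 \
    #       --n_clusters 27 --nbatch 2 --prediction_file --embedding_file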

    # Load the two data matrices, the (optional) labels, and the batch labels,
    # then one-hot encode the batch labels for the batch-correction network.
    data_mat = h5py.File(args.data_file, 'r')
    x1 = np.array(data_mat['X1'])
    x2 = np.array(data_mat['X2'])
    if not args.no_labels:
        y = np.array(data_mat['Y'])
    b = np.array(data_mat['Batch'])
    enc = OneHotEncoder()
    enc.fit(b.reshape(-1, 1))
    B = enc.transform(b.reshape(-1, 1)).toarray()
    data_mat.close()

    # Optional feature selection per modality
    if args.filter1:
        importantGenes = geneSelection(x1, n=args.f1, plot=False)
        x1 = x1[:, importantGenes]
    if args.filter2:
        importantGenes = geneSelection(x2, n=args.f2, plot=False)
        x2 = x2[:, importantGenes]

    # Preprocess the read-count matrix of each modality: size factors,
    # log-transformation, and input scaling (raw counts are kept in .raw
    # and used together with the size factors during training)
    adata1 = sc.AnnData(x1)
    adata1 = read_dataset(adata1,
                          transpose=False,
                          test_split=False,
                          copy=True)
    adata1 = normalize(adata1,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    adata2 = sc.AnnData(x2)
    adata2 = read_dataset(adata2,
                          transpose=False,
                          test_split=False,
                          copy=True)
    adata2 = normalize(adata2,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    input_size1 = adata1.n_vars
    input_size2 = adata2.n_vars

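    # encodeLayer sets the hidden-layer widths of the encoder; decodeLayer1 and
    # decodeLayer2 set those of the two modality-specific decoders. CLI values
    # arrive as strings, so cast them to int here.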
    encodeLayer = list(map(int, args.encodeLayer))
    decodeLayer1 = list(map(int, args.decodeLayer1))
    decodeLayer2 = list(map(int, args.decodeLayer2))

    model = scMultiClusterBatch(input_dim1=input_size1, input_dim2=input_size2, n_batch=args.nbatch, tau=args.tau,
                                encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
                                activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
                                cutoff=args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)

    print(str(model))

    os.makedirs(args.save_dir, exist_ok=True)

    # Stage 1: pretrain the autoencoder (or load pretrained weights)
    t0 = time()
    if args.ae_weights is None:
        model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                                   X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B,
                                   batch_size=args.batch_size, epochs=args.pretrain_epochs,
                                   ae_weights=args.ae_weight_file)
    else:
        if os.path.isfile(args.ae_weights):
            print("==> loading checkpoint '{}'".format(args.ae_weights))
            checkpoint = torch.load(args.ae_weights)
            model.load_state_dict(checkpoint['ae_state_dict'])
        else:
            print("==> no checkpoint found at '{}'".format(args.ae_weights))
            raise ValueError("Checkpoint file '{}' not found".format(args.ae_weights))

    print('Pretraining time: %d seconds.' % int(time() - t0))

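    # Estimate the number of clusters from the pretrained latent embedding when
    # --n_clusters is -1; otherwise use the user-supplied value.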
    latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device),
                               torch.tensor(adata2.X).to(args.device),
                               torch.tensor(B).to(args.device), batch_size=args.batch_size)
    latent = latent.cpu().numpy()
    if args.n_clusters == -1:
        n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
    else:
        print("n_clusters is set to " + str(args.n_clusters))
        n_clusters = args.n_clusters

    # Stage 2: clustering; pass ground-truth labels to fit() when available
    labels = None if args.no_labels else y
    y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                          X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B, y=labels,
                          n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
                          update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
    print('Total time: %d seconds.' % int(time() - t0))

    if args.prediction_file:
        np.savetxt(os.path.join(args.save_dir, str(args.run) + "_pred.csv"), y_pred, delimiter=",")

    if args.embedding_file:
        final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device),
                                         torch.tensor(adata2.X).to(args.device),
                                         torch.tensor(B).to(args.device), batch_size=args.batch_size)
        final_latent = final_latent.cpu().numpy()
        np.savetxt(os.path.join(args.save_dir, str(args.run) + "_embedding.csv"), final_latent, delimiter=",")

    if not args.no_labels:
        # AMI, NMI, and ARI are invariant to label permutation, so the raw
        # predicted labels can be scored directly against the ground truth.
        ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
        nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
        ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
        print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
    else:
        print("No labels available for evaluation!")