Diff of /src/run_scMDC.py [000000] .. [ac720d]

--- a
+++ b/src/run_scMDC.py
@@ -0,0 +1,184 @@
+import argparse
+import os
+from time import time
+
+from sklearn import metrics
+import torch
+
+from scMDC import scMultiCluster
+import numpy as np
+import h5py
+import scanpy as sc
+from preprocess import read_dataset, normalize, clr_normalize_each_cell
+from utils import *
+
+if __name__ == "__main__":
+
+    # hyperparameter settings
+    parser = argparse.ArgumentParser(description='train',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--n_clusters', default=27, type=int)
+    parser.add_argument('--cutoff', default=0.5, type=float, help='Ratio of epochs after which the combined layer starts training')
+    parser.add_argument('--batch_size', default=256, type=int)
+    parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5')
+    parser.add_argument('--maxiter', default=5000, type=int)
+    parser.add_argument('--pretrain_epochs', default=400, type=int)
+    parser.add_argument('--gamma', default=.1, type=float,
+                        help='coefficient of clustering loss')
+    parser.add_argument('--tau', default=1., type=float,
+                        help='fuzziness of clustering loss')
+    parser.add_argument('--phi1', default=0.001, type=float,
+                        help='coefficient of KL loss in pretraining stage')
+    parser.add_argument('--phi2', default=0.001, type=float,
+                        help='coefficient of KL loss in clustering stage')
+    parser.add_argument('--update_interval', default=1, type=int)
+    parser.add_argument('--tol', default=0.001, type=float)
+    parser.add_argument('--lr', default=1., type=float)
+    parser.add_argument('--ae_weights', default=None)
+    parser.add_argument('--save_dir', default='results/')
+    parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
+    parser.add_argument('--resolution', default=0.2, type=float)
+    parser.add_argument('--n_neighbors', default=30, type=int)
+    parser.add_argument('--embedding_file', action='store_true', default=False)
+    parser.add_argument('--prediction_file', action='store_true', default=False)
+    parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16])
+    parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256])
+    parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20])
+    parser.add_argument('--sigma1', default=2.5, type=float)
+    parser.add_argument('--sigma2', default=1.5, type=float)
+    parser.add_argument('--f1', default=1000, type=int, help='Number of mRNA features to keep after selection')
+    parser.add_argument('--f2', default=2000, type=int, help='Number of ADT/ATAC features to keep after selection')
+    parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection')
+    parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection')
+    parser.add_argument('--run', default=1, type=int)
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--no_labels', action='store_true', default=False)
+    args = parser.parse_args()
+    print(args)
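+
+    # Example invocation (hypothetical; flags correspond to the parser above):
+    #   python run_scMDC.py --data_file Normalized_filtered_BMNC_GSE128639_Seurat.h5 \
+    #       --n_clusters 27 --prediction_file --embedding_file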
+    
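+    # Load the paired matrices from HDF5: X1 holds mRNA counts and X2 holds
+    # ADT/ATAC counts (cells x features), with cell labels in Y when available.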
+    data_mat = h5py.File(args.data_file, 'r')
+    x1 = np.array(data_mat['X1'])
+    x2 = np.array(data_mat['X2'])
+    if not args.no_labels:
+        y = np.array(data_mat['Y'])
+    data_mat.close()
+
+    # Optional per-modality feature selection
+    if args.filter1:
+        importantGenes = geneSelection(x1, n=args.f1, plot=False)
+        x1 = x1[:, importantGenes]
+    if args.filter2:
+        importantGenes = geneSelection(x2, n=args.f2, plot=False)
+        x2 = x2[:, importantGenes]
+        
+    # preprocessing scRNA-seq read counts matrix
+    adata1 = sc.AnnData(x1)
+    #adata1.obs['Group'] = y
+
+    adata1 = read_dataset(adata1,
+                          transpose=False,
+                          test_split=False,
+                          copy=True)
+
+    adata1 = normalize(adata1,
+                       size_factors=True,
+                       normalize_input=True,
+                       logtrans_input=True)
+    
+    adata2 = sc.AnnData(x2)
+    #adata2.obs['Group'] = y
+    adata2 = read_dataset(adata2,
+                          transpose=False,
+                          test_split=False,
+                          copy=True)
+
+    adata2 = normalize(adata2,
+                       size_factors=True,
+                       normalize_input=True,
+                       logtrans_input=True)
+
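+    # Alternative for ADT counts: centered log-ratio (CLR) normalization per cell;
+    # uncomment to apply it: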
+    #adata2 = clr_normalize_each_cell(adata2)
+
+    input_size1 = adata1.n_vars
+    input_size2 = adata2.n_vars
+    
+    encodeLayer = list(map(int, args.encodeLayer))
+    decodeLayer1 = list(map(int, args.decodeLayer1))
+    decodeLayer2 = list(map(int, args.decodeLayer2))
+    
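+    # Build the multi-modal model: one shared encoder over both modalities and a
+    # separate decoder per modality (decodeLayer1 for mRNA, decodeLayer2 for
+    # ADT/ATAC). gamma weights the clustering loss and phi1/phi2 the KL terms
+    # (per the argparse help above); sigma1/sigma2 are assumed to be per-modality
+    # noise factors for the denoising autoencoder.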
+    model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
+                        encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
+                        activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma, 
+                        cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)
+    
+    print(str(model))
+    
+    if not os.path.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    t0 = time()
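+    # Stage 1: pretrain the autoencoder on both modalities, or load cached
+    # weights from --ae_weights if provided.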
+    if args.ae_weights is None:
+        model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 
+                X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, batch_size=args.batch_size, 
+                epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file)
+    else:
+        if os.path.isfile(args.ae_weights):
+            print("==> loading checkpoint '{}'".format(args.ae_weights))
+            checkpoint = torch.load(args.ae_weights)
+            model.load_state_dict(checkpoint['ae_state_dict'])
+        else:
+            print("==> no checkpoint found at '{}'".format(args.ae_weights))
+            raise ValueError
+    
+    print('Pretraining time: %d seconds.' % int(time() - t0))
+    
+    # Determine the number of clusters k (estimated from the pretrained latent space when n_clusters == -1)
+    latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
+    latent = latent.cpu().numpy()
+    if args.n_clusters == -1:
+        n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
+    else:
+        print("n_clusters is set to " + str(args.n_clusters))
+        n_clusters = args.n_clusters
+
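+    # Stage 2: joint deep clustering starting from the pretrained latent space.
+    # When labels are available they are passed to fit(), presumably for
+    # progress reporting during training.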
+    y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
+            X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors,
+            y=None if args.no_labels else y,
+            n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
+            update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
+    print('Total time: %d seconds.' % int(time() - t0))
+    
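+    # Optionally write the predicted labels and the final embedding as CSV files under save_dir.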
+    if args.prediction_file:
+        if not args.no_labels:
+            y_pred_ = best_map(y, y_pred) - 1
+            np.savetxt(os.path.join(args.save_dir, str(args.run) + "_pred.csv"), y_pred_, delimiter=",")
+        else:
+            np.savetxt(os.path.join(args.save_dir, str(args.run) + "_pred.csv"), y_pred, delimiter=",")
+    
+    if args.embedding_file:
+        final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
+        final_latent = final_latent.cpu().numpy()
+        np.savetxt(os.path.join(args.save_dir, str(args.run) + "_embedding.csv"), final_latent, delimiter=",")
+    
+    if not args.no_labels:
+        # AMI, NMI, and ARI are invariant to label permutations, so no
+        # cluster-to-label matching (best_map) is needed before scoring.
+        ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
+        nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
+        ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
+        print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
+    else:
+        print("No labels for evaluation!")