Switch to side-by-side view

--- a
+++ b/paper/Application on scRNA and scATAC datasets/run_VITAE.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+import pandas as pd
+import scanpy as sc
+import numpy as np
+import os
+import sys  # disabled alternative kept from the original: sys.path.insert(0, '../..')
+sys.path.append('/home/jinhongd/jingshu/VITAE-mm-pi-mm_added_pi')  # NOTE(review): machine-specific absolute path added to the import search path before `import VITAE`; adjust per machine or install VITAE instead
+print(os.getcwd())  # print the working directory; the relative output `path` defined below resolves against it
+import VITAE
+from VITAE.utils import load_data, reset_random_seeds
+import tensorflow as tf
+import random
+import h5py
+import matplotlib.pyplot as plt
+
+path = 'paper/An Application on scRNA and scATAC datasets/'  # output directory (relative to the working directory) for the figures and checkpoint written below
+os.makedirs(path, exist_ok=True)  # NOTE(review): this name has a leading "An", unlike the directory that holds this script - confirm that is intended
+
+# Load the scATAC and scRNA datasets with VITAE's load_data (presumably 'data' is the data directory - TODO confirm).
+adata_atac = load_data('data', 'human_hematopoiesis_scATAC')
+adata_rna = load_data('data', 'human_hematopoiesis_scRNA')
+
+celltype_exclude = ['CD4.M', 'CD4.N', 'CD8.CM', 'CD8.EM', 'CD8.N', 'NK', 'Plasma', 'cDC', 'CD16.Mono']  # 'grouping' labels to drop from both datasets
+adata_atac = adata_atac[~np.isin(adata_atac.obs['grouping'], celltype_exclude),:]  # keep only observations whose 'grouping' is not in the exclusion list
+adata_rna = adata_rna[~np.isin(adata_rna.obs['grouping'], celltype_exclude),:]
+
+# Preprocess: flag highly variable features per dataset on a normalized, log1p-transformed copy, then keep the union of both flag sets.
+hvg = []  # one sc.pp.highly_variable_genes result per dataset, in the order [scATAC, scRNA]
+for adata in [adata_atac, adata_rna]:  
+    dd = adata.copy()    # work on a copy; the normalization below is applied to `dd` only
+    sc.pp.normalize_total(dd, target_sum=1e4)
+    sc.pp.log1p(dd)
+    hvg.append(
+        sc.pp.highly_variable_genes(dd, inplace=False))  # inplace=False: use the returned table instead of annotating `dd`
+id_bool_genes = (hvg[0]['highly_variable']|hvg[1]['highly_variable']).values  # element-wise OR of the two flags; assumes both datasets list the same features in the same order - TODO confirm
+adata_atac = adata_atac[:,id_bool_genes]  # apply the same column mask to both datasets
+adata_rna = adata_rna[:,id_bool_genes]
+
+adata = adata_rna.concatenate(adata_atac, index_unique=None)  # stack scRNA first, then scATAC; index_unique=None leaves the observation names unchanged
+adata.obs['id_dataset'] = adata.obs['batch'].cat.rename_categories({'0': 'scRNA', '1': 'scATAC'})  # rename the 'batch' categories '0'/'1' to dataset names (presumably 'batch' follows the concatenation order - confirm)
+adata.obs['location'] = adata.obs['covariate_0'].str.split('_', expand=True).iloc[:,0]  # first '_'-separated field of 'covariate_0'
+adata.obs['location'] = adata.obs['location'].astype('category')
+adata.obs['tissue'] = adata.obs['covariate_0'].str.split('_', expand=True).iloc[:,1].str.split('T', expand=True).iloc[:,1]  # second '_' field, split on 'T', second piece; assumes values look like '<location>_<day>T<tissue>' - TODO confirm
+adata.obs['tissue'] = adata.obs['tissue'].astype('category')
+adata.obs['day'] = adata.obs['covariate_0'].str.split('_', expand=True).iloc[:,1].str.split('T', expand=True).iloc[:,0]  # second '_' field, split on 'T', first piece
+adata.obs['day'] = adata.obs['day'].astype('category')
+
+sc.pp.normalize_total(adata, target_sum=1e4)  # rescale each observation so its total equals 1e4
+sc.pp.log1p(adata)
+sc.pp.scale(adata, max_value=10)  # standardize each feature; scaled values above 10 are clipped
+
+# Merge small cell types: every 'grouping' label listed in a dict value is relabelled to the corresponding dict key.
+dict_merge = {
+    'Baso.Eryth':['Early.Baso','Early.Eryth', 'Late.Eryth'],
+    'GMP':['GMP', 'GMP.Neut']
+             }
+merged_groupings = adata.obs['grouping'].astype(str).values  # string array of labels, edited in place by the loop below
+for key in dict_merge.keys():
+    merged_groupings[
+        np.isin(merged_groupings, dict_merge[key])] = key
+adata.obs["grouping"] = merged_groupings  # write the merged labels back, then store them as a categorical column
+adata.obs["grouping"] = adata.obs["grouping"].astype("category")
+
+cond_group = np.unique(merged_groupings).astype(str)  # distinct merged labels
+for group in cond_group:
+    col_name = 'cond_'+group    # one 'cond_<group>' column per merged label
+    adata.obs[col_name] = np.where(merged_groupings==group, adata.obs['id_dataset'].values, np.nan)  # dataset label for observations in this group, NaN for all others
+cond = np.char.add('cond_', cond_group)  # names of the columns created above; passed to the model as `conditions`
+adata.obs[cond] = adata.obs[cond].astype("category")
+
+
+# Run the model: build it, pre-train, initialize the latent space from 'grouping', train, then infer, plot and save the results.
+reset_random_seeds(400)  # fixed seed value (400)
+tf.keras.backend.clear_session()  # reset Keras global state before building the model
+model = VITAE.VITAE(adata = adata, covariates=['id_dataset'], conditions=cond,  # dataset label as covariate; the 'cond_*' columns as conditions
+                    model_type = 'Gaussian', 
+                    npc=128, hidden_layers = [32,16], latent_space_dim=8)
+
+
+model.pre_train(gamma=0.6, phi=0.6, early_stopping_tolerance = 0.01, early_stopping_relative=True) 
+model.visualize_latent(color = ['id_dataset','grouping','location','tissue','day'], method = "UMAP")
+plt.savefig(path+"fig_pretrain.png", bbox_inches="tight")  # save the current matplotlib figure (presumably the one drawn by visualize_latent - confirm)
+
+model.init_latent_space(cluster_label='grouping', ratio_prune=0.5)  # the merged 'grouping' labels are passed as the cluster labels
+
+model.train(gamma=1., phi=1., early_stopping_tolerance = 0.01, early_stopping_relative=True)  # note: gamma and phi are 1.0 here versus 0.6 in pre_train
+model.posterior_estimation()
+model.visualize_latent(color = ['vitae_new_clustering','grouping','id_dataset','location','tissue','day'], method = "UMAP")  # 'vitae_new_clustering' is presumably added by the model steps above - confirm
+plt.savefig(path+"fig_train.png", bbox_inches="tight")
+model.infer_backbone(method = "modified_map")  # first of two backbone-inference methods; each is plotted and saved to its own file
+model.plot_backbone(color='grouping')
+plt.savefig(path+"fig_traj_modified_map.png", bbox_inches="tight")
+
+model.infer_backbone(method = "raw_map")  # second method, same plotting steps
+model.plot_backbone(color='grouping')
+plt.savefig(path+"fig_traj_raw_map.png", bbox_inches="tight")
+
+model.save_model(
+    path_to_file=path+'weight/model_inference.checkpoint',  # NOTE(review): the 'weight/' subdirectory is not created by this script - confirm save_model creates it
+    save_adata=True
+)