Switch to side-by-side view

--- a
+++ b/code_final/cell2loc_estimate_signatures.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python
+
+import argparse
+
+parser = argparse.ArgumentParser(description='Prepare cell2location reference signatures')
+parser.add_argument("infile", type=str,default=None,help='input h5ad file with reference dataset')
+parser.add_argument("output", type=str,default=None,help='folder to write output')
+parser.add_argument("labels_key", type=str,default=None,help='column in adata.obs to be used as cell type label')
+parser.add_argument("--batch_key", type=str,default=None,help='column in adata.obs to be used as bacth (single 10x reaction)')
+parser.add_argument("--categorical_covariate_key",default=None, action='append',type=str,help='column in adata.obs to be used as categrical covariates - donor, 3/5, etc (no covariates by default). Multiple columns can be supplied by repetitive usage of this option.')
+parser.add_argument("--continuous_covariate_key",default=None, action='append',type=str,help='column in adata.obs to be used as categrical covariates (no covariates by default). Multiple columns can be supplied by repetitive usage of this option.')
+parser.add_argument("--gene_id", type=str,default=None,help='column in adata.var to be used as gene id')
+parser.add_argument("--cell_count_cutoff", type=int,default=5,help='Gene filtering parameter: All genes detected in less than cell_count_cutoff cells will be excluded.')
+parser.add_argument("--cell_percentage_cutoff2", type=float,default=0.03,help='Gene filtering parameter: All genes detected in at least this percentage of cells will be included.')
+parser.add_argument("--nonz_mean_cutoff", type=float,default=1.12,help='Gene filtering parameter: genes detected in the number of cells between the above mentioned cutoffs are selected only when their average expression in non-zero cells is above this cutoff.')
+parser.add_argument("--max_epochs", type=int,default=250,help='max_epochs for training')
+parser.add_argument("--remove_genes_column", type=str,default=None,help='logical column in adata.var to be used to remove genes, for example mitochonrial. All genes with True in the column will be removed. None (defualt) mean to remove nothing.')
+parser.add_argument("--seed", type=int,default=1,help='scvi seed value')
+
+args = parser.parse_args()
+
+import sys
+import os
+import scanpy as sc
+from scipy.sparse import issparse
+import anndata
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+
+import cell2location
+from cell2location.utils.filtering import filter_genes
+from cell2location.models import RegressionModel
+
+import torch
+import scvi
+from scvi import REGISTRY_KEYS
+
+
+#######################
+# create output folder
+os.mkdir(args.output)
+
+sys.stdout = open(args.output+"/c2l.ref.log", "w")
+print(args)
+print("cuda avaliable: "+str(torch.cuda.is_available()))
+scvi.settings.seed = args.seed
+
+# read data
+ref = sc.read_h5ad(args.infile)
+
+
+mtcnt = np.sum([gene.startswith('MT-') for gene in ref.var.index])
+if mtcnt > 0:
+    print('There are ' + str(mtcnt) + 'MT genes! Consider to remove them!')
+
+if args.gene_id is not None:
+  ref.var[args.gene_id] = ref.var[args.gene_id].astype('string')
+  ref.var=ref.var.set_index(args.gene_id)
+  print('Raw: cells = '+str(ref.shape[0])+"; genes = " + str(ref.shape[1]))
+
+# filter genes
+if args.remove_genes_column != None:
+  print('Remove genes by "'+args.remove_genes_column+'". Following genes were removed:')
+  print(ref.var[ref.var[args.remove_genes_column]])
+  ref = ref[:,~ref.var[args.remove_genes_column]]
+
+
+# filter genes 
+selected = filter_genes(ref,
+                        cell_count_cutoff=args.cell_count_cutoff,
+                        cell_percentage_cutoff2=args.cell_percentage_cutoff2,
+                        nonz_mean_cutoff=args.nonz_mean_cutoff)
+                        
+plt.savefig(args.output+'/gene.filter.pdf') 
+
+print('Before filtering: cells = '+str(ref.shape[0])+"; genes = " + str(ref.shape[1]))
+ref = ref[:, selected].copy()
+print('After filtering: cells = '+str(ref.shape[0])+"; genes = " + str(ref.shape[1]))
+
+# remove slashes from celltype names
+ref.obs[args.labels_key] = ref.obs[args.labels_key].astype(str).str.replace('/','_')
+
+# train
+cell2location.models.RegressionModel.setup_anndata(adata=ref,
+                        # 10X reaction / sample / batch
+                        batch_key=args.batch_key,
+                        # cell type, covariate used for constructing signatures
+                        labels_key=args.labels_key,
+                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
+                        categorical_covariate_keys=args.categorical_covariate_key,
+                        continuous_covariate_keys=args.continuous_covariate_key
+                       )
+
+mod = RegressionModel(ref)
+
+mod.view_anndata_setup()
+
+#mod.train(max_epochs=args.max_epochs,use_gpu=True,progress_bar_refresh_rate=0)
+mod.train(max_epochs=args.max_epochs,progress_bar_refresh_rate=0)
+
+# plot ELBO loss history during training, removing first 20 epochs from the plot
+fig, ax = plt.subplots()
+mod.plot_history(20)
+plt.savefig(args.output+'/train.history.pdf')
+
+ref = mod.export_posterior(
+    ref, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
+)
+mod.save(args.output+"/rsignatures", overwrite=True)
+# most likely I do not need this file
+ref.write(args.output+"/rsignatures/sc.h5ad")
+
+# save signatures
+inf_aver = ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}' for i in ref.uns['mod']['factor_names']]].copy()
+inf_aver.columns = ref.uns['mod']['factor_names']
+inf_aver.to_csv(args.output+'/rsignatures/inf_aver.csv')
+
+
+# function to plot QCs into file
+def plot_QC1(m,plot,summary_name: str = "means",use_n_obs: int = 1000):
+  if use_n_obs is not None:
+    ind_x = np.random.choice(m.adata_manager.adata.n_obs, np.min((use_n_obs, m.adata.n_obs)), replace=False)
+  else:
+    ind_x = None
+  m.expected_nb_param = m.module.model.compute_expected(
+    m.samples[f"post_sample_{summary_name}"], m.adata_manager, ind_x=ind_x
+  )
+  x_data = m.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)[ind_x, :]
+  if issparse(x_data):
+    x_data = np.asarray(x_data.toarray())
+  
+  mu = m.expected_nb_param["mu"]
+  data_node = x_data
+  plot.hist2d(np.log10(data_node.flatten()+1), np.log10(mu.flatten()+1), bins=50, norm=mpl.colors.LogNorm())
+  plot.set_title("Reconstruction accuracy")
+  plot.set(xlabel="Data, log10", ylabel="Posterior sample, values, log10")
+
+
+def plot_QC2(m,plot,summary_name: str = "means",use_n_obs: int = 1000,scale_average_detection: bool = True):
+  inf_aver = m.samples[f"post_sample_{summary_name}"]["per_cluster_mu_fg"].T
+  if scale_average_detection and ("detection_y_c" in list(m.samples[f"post_sample_{summary_name}"].keys())):
+    inf_aver = inf_aver * m.samples[f"post_sample_{summary_name}"]["detection_y_c"].mean()
+  aver = m._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
+  aver = aver[m.factor_names_]
+  plot.hist2d(
+    np.log10(aver.values.flatten() + 1),
+    np.log10(inf_aver.flatten() + 1),
+    bins=50,
+    norm=mpl.colors.LogNorm(),)
+  plot.set(xlabel="Mean expression for every gene in every cluster", ylabel="Estimated expression for every gene in every cluster")
+
+
+# unfortunatelly it may not work specifically in case of underpopulated covariates/cell_types. It cannot be fixed on this level, so I'll use "try"
+# see https://github.com/BayraktarLab/cell2location/issues/74
+fig, (ax1,ax2) = plt.subplots(1,2)
+try:
+    plot_QC1(mod,plot=ax1,use_n_obs=10000)
+except Exception as e:
+    print(e)
+
+try:
+    plot_QC2(mod,plot=ax2)
+except Exception as e:
+    print(e)
+
+plt.tight_layout()
+plt.savefig(args.output+'/train.QC.pdf')
+
+cell2location.utils.list_imported_modules()
+sys.stdout.close()