a b/code_final/cell2loc_estimate_signatures.py
1
#!/usr/bin/env python
2
3
import argparse
4
5
# Command-line interface: three required positional arguments (input h5ad file,
# output folder, cell-type label column) followed by optional settings for the
# model covariates, gene filtering, training length and random seed.
parser = argparse.ArgumentParser(description='Prepare cell2location reference signatures')
parser.add_argument("infile", type=str, help='input h5ad file with reference dataset')
parser.add_argument("output", type=str, help='folder to write output')
parser.add_argument("labels_key", type=str, help='column in adata.obs to be used as cell type label')
parser.add_argument("--batch_key", type=str, default=None,
                    help='column in adata.obs to be used as batch (single 10x reaction)')
# The two covariate options use action='append': each repetition of the flag
# adds one column name to a list; the value stays None when the flag is absent.
parser.add_argument("--categorical_covariate_key", default=None, action='append', type=str,
                    help='column in adata.obs to be used as categorical covariates - donor, 3/5, etc '
                         '(no covariates by default). Multiple columns can be supplied by repetitive '
                         'usage of this option.')
parser.add_argument("--continuous_covariate_key", default=None, action='append', type=str,
                    help='column in adata.obs to be used as continuous covariates '
                         '(no covariates by default). Multiple columns can be supplied by repetitive '
                         'usage of this option.')
parser.add_argument("--gene_id", type=str, default=None, help='column in adata.var to be used as gene id')
parser.add_argument("--cell_count_cutoff", type=int, default=5,
                    help='Gene filtering parameter: All genes detected in less than cell_count_cutoff '
                         'cells will be excluded.')
parser.add_argument("--cell_percentage_cutoff2", type=float, default=0.03,
                    help='Gene filtering parameter: All genes detected in at least this percentage of '
                         'cells will be included.')
parser.add_argument("--nonz_mean_cutoff", type=float, default=1.12,
                    help='Gene filtering parameter: genes detected in the number of cells between the '
                         'above mentioned cutoffs are selected only when their average expression in '
                         'non-zero cells is above this cutoff.')
parser.add_argument("--max_epochs", type=int, default=250, help='max_epochs for training')
parser.add_argument("--remove_genes_column", type=str, default=None,
                    help='logical column in adata.var to be used to remove genes, for example '
                         'mitochondrial. All genes with True in the column will be removed. '
                         'None (default) means to remove nothing.')
parser.add_argument("--seed", type=int, default=1, help='scvi seed value')

args = parser.parse_args()
21
22
import sys
23
import os
24
import scanpy as sc
25
from scipy.sparse import issparse
26
import anndata
27
import pandas as pd
28
import numpy as np
29
import matplotlib.pyplot as plt
30
import matplotlib as mpl
31
32
import cell2location
33
from cell2location.utils.filtering import filter_genes
34
from cell2location.models import RegressionModel
35
36
import torch
37
import scvi
38
from scvi import REGISTRY_KEYS
39
40
41
#######################
# Create the output folder (os.mkdir raises if it already exists), then send
# everything printed from here on to a log file inside that folder.
os.mkdir(args.output)

log_path = f"{args.output}/c2l.ref.log"
sys.stdout = open(log_path, "w")

# Record the run configuration and whether torch reports CUDA as available.
print(args)
print(f"cuda avaliable: {torch.cuda.is_available()}")

# Seed handed to scvi's global settings.
scvi.settings.seed = args.seed
49
50
# read data
ref = sc.read_h5ad(args.infile)

# Warn when the current gene index contains names starting with 'MT-'.
# This runs before any re-indexing by --gene_id, i.e. on the index as stored in the file.
mtcnt = sum(gene.startswith('MT-') for gene in ref.var.index)
if mtcnt > 0:
    print('There are ' + str(mtcnt) + ' MT genes! Consider to remove them!')

# Optionally use another adata.var column (cast to string) as the gene index.
if args.gene_id is not None:
    ref.var[args.gene_id] = ref.var[args.gene_id].astype('string')
    ref.var = ref.var.set_index(args.gene_id)

# Report the raw dimensions whether or not --gene_id was supplied.
print('Raw: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))

# Drop genes flagged True in a logical adata.var column, listing them in the log first.
if args.remove_genes_column is not None:
    print('Remove genes by "' + args.remove_genes_column + '". Following genes were removed:')
    remove_mask = ref.var[args.remove_genes_column]
    print(ref.var[remove_mask])
    ref = ref[:, ~remove_mask]

# Select genes with cell2location's filter_genes using the CLI thresholds.
selected = filter_genes(
    ref,
    cell_count_cutoff=args.cell_count_cutoff,
    cell_percentage_cutoff2=args.cell_percentage_cutoff2,
    nonz_mean_cutoff=args.nonz_mean_cutoff,
)

# Saves the current matplotlib figure - presumably the diagnostic plot drawn by
# filter_genes (NOTE(review): confirm against the cell2location version in use).
plt.savefig(args.output + '/gene.filter.pdf')

print('Before filtering: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))
# .copy() materialises the subset as an independent object instead of a view.
ref = ref[:, selected].copy()
print('After filtering: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))
81
82
# remove slashes from celltype names (literal replacement, not a regular expression)
ref.obs[args.labels_key] = ref.obs[args.labels_key].astype(str).str.replace('/', '_', regex=False)

# train
cell2location.models.RegressionModel.setup_anndata(
    adata=ref,
    # 10X reaction / sample / batch
    batch_key=args.batch_key,
    # cell type, covariate used for constructing signatures
    labels_key=args.labels_key,
    # multiplicative technical effects (platform, 3' vs 5', donor effect)
    categorical_covariate_keys=args.categorical_covariate_key,
    continuous_covariate_keys=args.continuous_covariate_key,
)

mod = RegressionModel(ref)

mod.view_anndata_setup()

# The GPU is not forced for training (an earlier variant of this call passed use_gpu=True).
mod.train(max_epochs=args.max_epochs, progress_bar_refresh_rate=0)

# plot ELBO loss history during training, removing first 20 epochs from the plot
# A fresh figure is opened first so the history is not drawn onto an earlier figure.
plt.subplots()
mod.plot_history(20)
plt.savefig(args.output + '/train.history.pdf')

# Export the posterior into the AnnData object. The GPU is requested only when
# torch reports CUDA as available, so the script can also finish on CPU-only hosts.
ref = mod.export_posterior(
    ref,
    sample_kwargs={
        'num_samples': 1000,
        'batch_size': 2500,
        'use_gpu': torch.cuda.is_available(),
    },
)
mod.save(args.output + "/rsignatures", overwrite=True)
# most likely I do not need this file
ref.write(args.output + "/rsignatures/sc.h5ad")

# save signatures: one column per factor, renamed from the prefixed varm column
# names to the plain factor names
factor_names = ref.uns['mod']['factor_names']
inf_aver = ref.varm['means_per_cluster_mu_fg'][
    [f'means_per_cluster_mu_fg_{i}' for i in factor_names]
].copy()
inf_aver.columns = factor_names
inf_aver.to_csv(args.output + '/rsignatures/inf_aver.csv')
119
120
121
# function to plot QCs into file
122
def plot_QC1(m, plot, summary_name: str = "means", use_n_obs: int = 1000):
    """Draw a 2D histogram of observed values against the model's expected values.

    Both axes are shown as log10(value + 1).

    Parameters
    ----------
    m
        Model object; this function reads ``m.adata_manager``, ``m.samples`` and
        ``m.module.model.compute_expected``.
    plot
        Axes-like object providing ``hist2d``, ``set_title`` and ``set``.
    summary_name
        Suffix selecting the posterior summary, read from
        ``m.samples[f"post_sample_{summary_name}"]``.
    use_n_obs
        Maximum number of observations to draw at random without replacement.
        ``None`` uses every observation.

    Side effects
    ------------
    Stores the result of ``compute_expected`` on ``m.expected_nb_param``.
    """
    n_obs = m.adata_manager.adata.n_obs
    if use_n_obs is not None:
        # Never ask for more observations than exist: sampling is without replacement.
        ind_x = np.random.choice(n_obs, min(use_n_obs, n_obs), replace=False)
    else:
        ind_x = None

    m.expected_nb_param = m.module.model.compute_expected(
        m.samples[f"post_sample_{summary_name}"], m.adata_manager, ind_x=ind_x
    )

    x_data = m.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    if ind_x is not None:
        # Only subset when indices were drawn; indexing with None would not select rows.
        x_data = x_data[ind_x, :]
    if issparse(x_data):
        x_data = np.asarray(x_data.toarray())

    mu = m.expected_nb_param["mu"]
    plot.hist2d(
        np.log10(x_data.flatten() + 1),
        np.log10(mu.flatten() + 1),
        bins=50,
        norm=mpl.colors.LogNorm(),
    )
    plot.set_title("Reconstruction accuracy")
    plot.set(xlabel="Data, log10", ylabel="Posterior sample, values, log10")
139
140
141
def plot_QC2(m, plot, summary_name: str = "means", use_n_obs: int = 1000, scale_average_detection: bool = True):
    """Draw a 2D histogram of per-cluster average expression against the estimated expression.

    Both axes are shown as log10(value + 1).

    Parameters
    ----------
    m
        Model object; this function reads ``m.samples``, ``m.factor_names_`` and
        calls ``m._compute_cluster_averages``.
    plot
        Axes-like object providing ``hist2d`` and ``set``.
    summary_name
        Suffix selecting the posterior summary, read from
        ``m.samples[f"post_sample_{summary_name}"]``.
    use_n_obs
        Accepted but not used by this function.
    scale_average_detection
        When true and the posterior summary contains ``"detection_y_c"``, the
        estimates are multiplied by the mean of that entry.
    """
    posterior = m.samples[f"post_sample_{summary_name}"]

    estimated = posterior["per_cluster_mu_fg"].T
    if scale_average_detection and "detection_y_c" in posterior:
        estimated = estimated * posterior["detection_y_c"].mean()

    # Cluster averages, with columns ordered like the model's factor names.
    observed = m._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)
    observed = observed[m.factor_names_]

    x_vals = np.log10(observed.values.flatten() + 1)
    y_vals = np.log10(estimated.flatten() + 1)
    plot.hist2d(x_vals, y_vals, bins=50, norm=mpl.colors.LogNorm())
    plot.set(
        xlabel="Mean expression for every gene in every cluster",
        ylabel="Estimated expression for every gene in every cluster",
    )
153
154
155
# Unfortunately the QC plots may fail, specifically in case of underpopulated
# covariates/cell types. It cannot be fixed on this level, so each plot is
# wrapped in "try" and a failure is only logged.
# see https://github.com/BayraktarLab/cell2location/issues/74
fig, (ax1, ax2) = plt.subplots(1, 2)
try:
    plot_QC1(mod, plot=ax1, use_n_obs=10000)
except Exception as e:
    # Best effort: record which plot failed and why, then carry on.
    print(f"plot_QC1 failed: {e!r}")

try:
    plot_QC2(mod, plot=ax2)
except Exception as e:
    print(f"plot_QC2 failed: {e!r}")

plt.tight_layout()
plt.savefig(args.output + '/train.QC.pdf')

cell2location.utils.list_imported_modules()

# Close the log file and restore the interpreter's original stdout so that
# sys.stdout is not left bound to a closed file.
sys.stdout.close()
sys.stdout = sys.__stdout__