|
a |
|
b/code_final/cell2loc_estimate_signatures.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
|
|
|
3 |
import argparse |
|
|
4 |
|
|
|
5 |
# Command-line interface. Positional arguments are required; every option has a default.
parser = argparse.ArgumentParser(description='Prepare cell2location reference signatures')
parser.add_argument("infile", type=str, help='input h5ad file with reference dataset')
parser.add_argument("output", type=str, help='folder to write output')
parser.add_argument("labels_key", type=str, help='column in adata.obs to be used as cell type label')
parser.add_argument("--batch_key", type=str, default=None, help='column in adata.obs to be used as batch (single 10x reaction)')
# action='append' collects repeated options into a list (None when the option is never given)
parser.add_argument("--categorical_covariate_key", default=None, action='append', type=str, help='column in adata.obs to be used as categorical covariates - donor, 3/5, etc (no covariates by default). Multiple columns can be supplied by repetitive usage of this option.')
parser.add_argument("--continuous_covariate_key", default=None, action='append', type=str, help='column in adata.obs to be used as continuous covariates (no covariates by default). Multiple columns can be supplied by repetitive usage of this option.')
parser.add_argument("--gene_id", type=str, default=None, help='column in adata.var to be used as gene id')
# gene filtering thresholds, passed straight to cell2location's filter_genes
parser.add_argument("--cell_count_cutoff", type=int, default=5, help='Gene filtering parameter: All genes detected in less than cell_count_cutoff cells will be excluded.')
parser.add_argument("--cell_percentage_cutoff2", type=float, default=0.03, help='Gene filtering parameter: All genes detected in at least this percentage of cells will be included.')
parser.add_argument("--nonz_mean_cutoff", type=float, default=1.12, help='Gene filtering parameter: genes detected in the number of cells between the above mentioned cutoffs are selected only when their average expression in non-zero cells is above this cutoff.')
parser.add_argument("--max_epochs", type=int, default=250, help='max_epochs for training')
parser.add_argument("--remove_genes_column", type=str, default=None, help='logical column in adata.var to be used to remove genes, for example mitochondrial. All genes with True in the column will be removed. None (default) means to remove nothing.')
parser.add_argument("--seed", type=int, default=1, help='scvi seed value')

args = parser.parse_args()
|
|
21 |
|
|
|
22 |
import sys |
|
|
23 |
import os |
|
|
24 |
import scanpy as sc |
|
|
25 |
from scipy.sparse import issparse |
|
|
26 |
import anndata |
|
|
27 |
import pandas as pd |
|
|
28 |
import numpy as np |
|
|
29 |
import matplotlib.pyplot as plt |
|
|
30 |
import matplotlib as mpl |
|
|
31 |
|
|
|
32 |
import cell2location |
|
|
33 |
from cell2location.utils.filtering import filter_genes |
|
|
34 |
from cell2location.models import RegressionModel |
|
|
35 |
|
|
|
36 |
import torch |
|
|
37 |
import scvi |
|
|
38 |
from scvi import REGISTRY_KEYS |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
#######################
# create output folder; like os.mkdir, os.makedirs (without exist_ok) refuses to reuse an
# existing folder, but it also creates any missing parent folders
os.makedirs(args.output)

# Redirect everything printed from here on into a log file inside the output folder.
# buffering=1 (line buffering) keeps the log up to date while the long training runs.
sys.stdout = open(os.path.join(args.output, "c2l.ref.log"), "w", buffering=1, encoding="utf-8")
print(args)
print("cuda available: " + str(torch.cuda.is_available()))
scvi.settings.seed = args.seed

# read data
ref = sc.read_h5ad(args.infile)

# Warn about mitochondrial genes. This check uses the original adata.var index,
# i.e. it runs before the optional --gene_id re-indexing below.
mtcnt = int(np.sum(ref.var_names.str.startswith('MT-')))
if mtcnt > 0:
    print('There are ' + str(mtcnt) + ' MT genes! Consider removing them!')

# optionally use another adata.var column (cast to string) as the gene index
if args.gene_id is not None:
    ref.var[args.gene_id] = ref.var[args.gene_id].astype('string')
    ref.var = ref.var.set_index(args.gene_id)
print('Raw: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))
|
|
62 |
|
|
|
63 |
# filter genes: optionally drop genes flagged True in a logical adata.var column
# (for example mitochondrial genes)
if args.remove_genes_column is not None:
    remove_mask = ref.var[args.remove_genes_column]
    print('Remove genes by "' + args.remove_genes_column + '". Following genes were removed:')
    print(ref.var[remove_mask])
    # Slicing an AnnData returns a view; .copy() materialises it so that filter_genes
    # below (which presumably annotates ref.var with per-gene statistics — confirm for
    # the installed cell2location version) does not operate on a view.
    ref = ref[:, ~remove_mask].copy()


# filter genes by detection statistics using cell2location's heuristic
selected = filter_genes(ref,
                        cell_count_cutoff=args.cell_count_cutoff,
                        cell_percentage_cutoff2=args.cell_percentage_cutoff2,
                        nonz_mean_cutoff=args.nonz_mean_cutoff)

# save whatever filter_genes drew on the current matplotlib figure
plt.savefig(os.path.join(args.output, 'gene.filter.pdf'))

print('Before filtering: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))
ref = ref[:, selected].copy()
print('After filtering: cells = ' + str(ref.shape[0]) + "; genes = " + str(ref.shape[1]))

# remove slashes from celltype names (literal '/' replacement, not a regex)
ref.obs[args.labels_key] = ref.obs[args.labels_key].astype(str).str.replace('/', '_', regex=False)
|
|
84 |
|
|
|
85 |
# train: register the AnnData fields the regression model should use
cell2location.models.RegressionModel.setup_anndata(
    adata=ref,
    # 10X reaction / sample / batch
    batch_key=args.batch_key,
    # cell type, covariate used for constructing signatures
    labels_key=args.labels_key,
    # multiplicative technical effects (platform, 3' vs 5', donor effect)
    categorical_covariate_keys=args.categorical_covariate_key,
    continuous_covariate_keys=args.continuous_covariate_key,
)

mod = RegressionModel(ref)

mod.view_anndata_setup()

mod.train(max_epochs=args.max_epochs, progress_bar_refresh_rate=0)

# plot ELBO loss history during training, removing first 20 epochs from the plot
fig, ax = plt.subplots()
mod.plot_history(20)
plt.savefig(os.path.join(args.output, 'train.history.pdf'))

# Request the GPU for posterior sampling only when torch reports one: training above is
# not forced onto a GPU, and hard-coding use_gpu=True would demand CUDA on CPU-only hosts.
gpu_available = torch.cuda.is_available()
ref = mod.export_posterior(
    ref, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': gpu_available}
)
sig_dir = os.path.join(args.output, "rsignatures")
mod.save(sig_dir, overwrite=True)
# most likely I do not need this file
ref.write(os.path.join(sig_dir, "sc.h5ad"))

# save signatures: one column per factor name taken from ref.uns['mod']['factor_names']
inf_aver = ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}' for i in ref.uns['mod']['factor_names']]].copy()
inf_aver.columns = ref.uns['mod']['factor_names']
inf_aver.to_csv(os.path.join(sig_dir, 'inf_aver.csv'))
|
|
119 |
|
|
|
120 |
|
|
|
121 |
# function to plot QCs into file
def plot_QC1(m, plot, summary_name: str = "means", use_n_obs: int = 1000):
    """Draw observed counts vs posterior expected counts as a 2D histogram on ``plot``.

    Parameters
    ----------
    m
        Trained model exposing ``adata_manager``, ``module.model.compute_expected`` and
        ``samples`` (the RegressionModel trained in this script).
    plot
        Matplotlib axes to draw into.
    summary_name
        Posterior summary to use; reads ``m.samples[f"post_sample_{summary_name}"]``.
    use_n_obs
        Number of randomly chosen cells, capped at the number of cells in the data.
        ``None`` uses all cells.

    Side effect: stores the computed expectations on ``m.expected_nb_param``.
    """
    n_obs = m.adata_manager.adata.n_obs
    if use_n_obs is not None:
        ind_x = np.random.choice(n_obs, min(use_n_obs, n_obs), replace=False)
    else:
        ind_x = None
    m.expected_nb_param = m.module.model.compute_expected(
        m.samples[f"post_sample_{summary_name}"], m.adata_manager, ind_x=ind_x
    )
    x_data = m.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    # Subset only when cells were sampled: x[None, :] inserts a new axis instead of
    # selecting all rows, and is not supported for scipy sparse matrices.
    if ind_x is not None:
        x_data = x_data[ind_x, :]
    if issparse(x_data):
        x_data = np.asarray(x_data.toarray())

    mu = m.expected_nb_param["mu"]
    plot.hist2d(np.log10(x_data.flatten() + 1), np.log10(mu.flatten() + 1), bins=50, norm=mpl.colors.LogNorm())
    plot.set_title("Reconstruction accuracy")
    plot.set(xlabel="Data, log10", ylabel="Posterior sample, values, log10")
|
|
139 |
|
|
|
140 |
|
|
|
141 |
def plot_QC2(m, plot, summary_name: str = "means", use_n_obs: int = 1000, scale_average_detection: bool = True):
    """Compare per-cluster mean expression with the model's estimates on ``plot``.

    Parameters
    ----------
    m
        Trained model exposing ``samples``, ``_compute_cluster_averages`` and
        ``factor_names_`` (the RegressionModel trained in this script).
    plot
        Matplotlib axes that receives the 2D histogram.
    summary_name
        Posterior summary to use; reads ``m.samples[f"post_sample_{summary_name}"]``.
    use_n_obs
        Accepted for signature parity with ``plot_QC1``; not used here.
    scale_average_detection
        When True and the posterior contains ``detection_y_c``, multiply the
        estimates by its mean.
    """
    posterior = m.samples[f"post_sample_{summary_name}"]
    estimated = posterior["per_cluster_mu_fg"].T
    if scale_average_detection and "detection_y_c" in posterior.keys():
        estimated = estimated * posterior["detection_y_c"].mean()

    # observed averages per label, columns ordered like the model's factor names
    observed = m._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY)[m.factor_names_]

    observed_log = np.log10(observed.values.flatten() + 1)
    estimated_log = np.log10(estimated.flatten() + 1)
    plot.hist2d(observed_log, estimated_log, bins=50, norm=mpl.colors.LogNorm())
    plot.set(xlabel="Mean expression for every gene in every cluster", ylabel="Estimated expression for every gene in every cluster")
|
|
153 |
|
|
|
154 |
|
|
|
155 |
# QC plotting may fail, specifically for under-populated covariates/cell types. It cannot be
# fixed at this level, so each panel is drawn best-effort: a failure is logged and the rest
# of the figure is still saved.
# see https://github.com/BayraktarLab/cell2location/issues/74
fig, (ax1, ax2) = plt.subplots(1, 2)
try:
    plot_QC1(mod, plot=ax1, use_n_obs=10000)
except Exception as e:
    # repr keeps the exception type (a bare KeyError would otherwise print only the key)
    print('plot_QC1 failed: ' + repr(e))

try:
    plot_QC2(mod, plot=ax2)
except Exception as e:
    print('plot_QC2 failed: ' + repr(e))

plt.tight_layout()
plt.savefig(os.path.join(args.output, 'train.QC.pdf'))

# presumably prints the versions of imported modules into the log — confirm in cell2location.utils
cell2location.utils.list_imported_modules()

# Close the log file and point sys.stdout back at the real stream, so any later
# output (e.g. from atexit handlers) does not hit a closed file.
sys.stdout.close()
sys.stdout = sys.__stdout__