|
a |
|
b/bin/predict_model.py |
|
|
1 |
""" |
|
|
2 |
Code for evaluating a model's ability to generalize to cells that it wasn't trained on. |
|
|
3 |
Can only be used to evalute within a species. |
|
|
4 |
Generates raw predictions of data modality transfer, and optionally, plots. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import os |
|
|
8 |
import sys |
|
|
9 |
from typing import * |
|
|
10 |
import functools |
|
|
11 |
import logging |
|
|
12 |
import argparse |
|
|
13 |
import copy |
|
|
14 |
|
|
|
15 |
import scipy |
|
|
16 |
|
|
|
17 |
import anndata as ad |
|
|
18 |
import scanpy as sc |
|
|
19 |
|
|
|
20 |
import torch |
|
|
21 |
import skorch |
|
|
22 |
|
|
|
23 |
SRC_DIR = os.path.join( |
|
|
24 |
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), |
|
|
25 |
"babel", |
|
|
26 |
) |
|
|
27 |
assert os.path.isdir(SRC_DIR) |
|
|
28 |
sys.path.append(SRC_DIR) |
|
|
29 |
import sc_data_loaders |
|
|
30 |
import loss_functions |
|
|
31 |
import model_utils |
|
|
32 |
import plot_utils |
|
|
33 |
import adata_utils |
|
|
34 |
import utils |
|
|
35 |
from models import autoencoders |
|
|
36 |
|
|
|
37 |
DATA_DIR = os.path.join(os.path.dirname(SRC_DIR), "data") |
|
|
38 |
assert os.path.isdir(DATA_DIR) |
|
|
39 |
|
|
|
40 |
logging.basicConfig(level=logging.INFO) |
|
|
41 |
|
|
|
42 |
DATASET_NAME = "" |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def do_evaluation_rna_from_rna( |
|
|
46 |
spliced_net, |
|
|
47 |
sc_dual_full_dataset, |
|
|
48 |
gene_names: str, |
|
|
49 |
atac_names: str, |
|
|
50 |
outdir: str, |
|
|
51 |
ext: str, |
|
|
52 |
marker_genes: List[str], |
|
|
53 |
prefix: str = "", |
|
|
54 |
): |
|
|
55 |
""" |
|
|
56 |
Evaluate the given network on the dataset |
|
|
57 |
""" |
|
|
58 |
# Do inference and plotting |
|
|
59 |
### RNA > RNA |
|
|
60 |
logging.info("Inferring RNA from RNA...") |
|
|
61 |
sc_rna_full_preds = spliced_net.translate_1_to_1(sc_dual_full_dataset) |
|
|
62 |
sc_rna_full_preds_anndata = sc.AnnData( |
|
|
63 |
sc_rna_full_preds, |
|
|
64 |
obs=sc_dual_full_dataset.dataset_x.data_raw.obs, |
|
|
65 |
) |
|
|
66 |
sc_rna_full_preds_anndata.var_names = gene_names |
|
|
67 |
|
|
|
68 |
logging.info("Writing RNA from RNA") |
|
|
69 |
sc_rna_full_preds_anndata.write( |
|
|
70 |
os.path.join(outdir, f"{prefix}_rna_rna_adata.h5ad".strip("_")) |
|
|
71 |
) |
|
|
72 |
if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None: |
|
|
73 |
logging.info("Plotting RNA from RNA") |
|
|
74 |
plot_utils.plot_scatter_with_r( |
|
|
75 |
sc_dual_full_dataset.dataset_x.size_norm_counts.X, |
|
|
76 |
sc_rna_full_preds, |
|
|
77 |
one_to_one=True, |
|
|
78 |
logscale=True, |
|
|
79 |
density_heatmap=True, |
|
|
80 |
title=f"{DATASET_NAME} RNA > RNA".strip(), |
|
|
81 |
fname=os.path.join(outdir, f"{prefix}_rna_rna_log.{ext}".strip("_")), |
|
|
82 |
) |
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def do_evaluation_atac_from_rna( |
|
|
86 |
spliced_net, |
|
|
87 |
sc_dual_full_dataset, |
|
|
88 |
gene_names: str, |
|
|
89 |
atac_names: str, |
|
|
90 |
outdir: str, |
|
|
91 |
ext: str, |
|
|
92 |
marker_genes: List[str], |
|
|
93 |
prefix: str = "", |
|
|
94 |
): |
|
|
95 |
### RNA > ATAC |
|
|
96 |
logging.info("Inferring ATAC from RNA") |
|
|
97 |
sc_rna_atac_full_preds = spliced_net.translate_1_to_2(sc_dual_full_dataset) |
|
|
98 |
sc_rna_atac_full_preds_anndata = sc.AnnData( |
|
|
99 |
scipy.sparse.csr_matrix(sc_rna_atac_full_preds), |
|
|
100 |
obs=sc_dual_full_dataset.dataset_x.data_raw.obs, |
|
|
101 |
) |
|
|
102 |
sc_rna_atac_full_preds_anndata.var_names = atac_names |
|
|
103 |
logging.info("Writing ATAC from RNA") |
|
|
104 |
sc_rna_atac_full_preds_anndata.write( |
|
|
105 |
os.path.join(outdir, f"{prefix}_rna_atac_adata.h5ad".strip("_")) |
|
|
106 |
) |
|
|
107 |
|
|
|
108 |
if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None: |
|
|
109 |
logging.info("Plotting ATAC from RNA") |
|
|
110 |
plot_utils.plot_auroc( |
|
|
111 |
utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(), |
|
|
112 |
utils.ensure_arr(sc_rna_atac_full_preds).flatten(), |
|
|
113 |
title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(), |
|
|
114 |
fname=os.path.join(outdir, f"{prefix}_rna_atac_auroc.{ext}".strip("_")), |
|
|
115 |
) |
|
|
116 |
# plot_utils.plot_auprc( |
|
|
117 |
# utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(), |
|
|
118 |
# utils.ensure_arr(sc_rna_atac_full_preds), |
|
|
119 |
# title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(), |
|
|
120 |
# fname=os.path.join(outdir, f"{prefix}_rna_atac_auprc.{ext}".strip("_")), |
|
|
121 |
# ) |
|
|
122 |
|
|
|
123 |
|
|
|
124 |
def do_evaluation_atac_from_atac( |
|
|
125 |
spliced_net, |
|
|
126 |
sc_dual_full_dataset, |
|
|
127 |
gene_names: str, |
|
|
128 |
atac_names: str, |
|
|
129 |
outdir: str, |
|
|
130 |
ext: str, |
|
|
131 |
marker_genes: List[str], |
|
|
132 |
prefix: str = "", |
|
|
133 |
): |
|
|
134 |
### ATAC > ATAC |
|
|
135 |
logging.info("Inferring ATAC from ATAC") |
|
|
136 |
sc_atac_full_preds = spliced_net.translate_2_to_2(sc_dual_full_dataset) |
|
|
137 |
sc_atac_full_preds_anndata = sc.AnnData( |
|
|
138 |
sc_atac_full_preds, |
|
|
139 |
obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True), |
|
|
140 |
) |
|
|
141 |
sc_atac_full_preds_anndata.var_names = atac_names |
|
|
142 |
logging.info("Writing ATAC from ATAC") |
|
|
143 |
|
|
|
144 |
# Infer marker bins |
|
|
145 |
# logging.info("Getting marker bins for ATAC from ATAC") |
|
|
146 |
# plot_utils.preprocess_anndata(sc_atac_full_preds_anndata) |
|
|
147 |
# adata_utils.find_marker_genes(sc_atac_full_preds_anndata) |
|
|
148 |
# inferred_marker_bins = adata_utils.flatten_marker_genes( |
|
|
149 |
# sc_atac_full_preds_anndata.uns["rank_genes_leiden"] |
|
|
150 |
# ) |
|
|
151 |
# logging.info(f"Found {len(inferred_marker_bins)} marker bins for ATAC from ATAC") |
|
|
152 |
# with open( |
|
|
153 |
# os.path.join(outdir, f"{prefix}_atac_atac_marker_bins.txt".strip("_")), "w" |
|
|
154 |
# ) as sink: |
|
|
155 |
# sink.write("\n".join(inferred_marker_bins) + "\n") |
|
|
156 |
|
|
|
157 |
sc_atac_full_preds_anndata.write( |
|
|
158 |
os.path.join(outdir, f"{prefix}_atac_atac_adata.h5ad".strip("_")) |
|
|
159 |
) |
|
|
160 |
if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None: |
|
|
161 |
logging.info("Plotting ATAC from ATAC") |
|
|
162 |
plot_utils.plot_auroc( |
|
|
163 |
utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(), |
|
|
164 |
utils.ensure_arr(sc_atac_full_preds).flatten(), |
|
|
165 |
title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(), |
|
|
166 |
fname=os.path.join(outdir, f"{prefix}_atac_atac_auroc.{ext}".strip("_")), |
|
|
167 |
) |
|
|
168 |
# plot_utils.plot_auprc( |
|
|
169 |
# utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(), |
|
|
170 |
# utils.ensure_arr(sc_atac_full_preds).flatten(), |
|
|
171 |
# title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(), |
|
|
172 |
# fname=os.path.join(outdir, f"{prefix}_atac_atac_auprc.{ext}".strip("_")), |
|
|
173 |
# ) |
|
|
174 |
|
|
|
175 |
# Remove some objects to free memory |
|
|
176 |
del sc_atac_full_preds |
|
|
177 |
del sc_atac_full_preds_anndata |
|
|
178 |
|
|
|
179 |
|
|
|
180 |
def do_evaluation_rna_from_atac( |
|
|
181 |
spliced_net, |
|
|
182 |
sc_dual_full_dataset, |
|
|
183 |
gene_names: str, |
|
|
184 |
atac_names: str, |
|
|
185 |
outdir: str, |
|
|
186 |
ext: str, |
|
|
187 |
marker_genes: List[str], |
|
|
188 |
prefix: str = "", |
|
|
189 |
): |
|
|
190 |
### ATAC > RNA |
|
|
191 |
logging.info("Inferring RNA from ATAC") |
|
|
192 |
sc_atac_rna_full_preds = spliced_net.translate_2_to_1(sc_dual_full_dataset) |
|
|
193 |
# Seurat expects everything to be sparse |
|
|
194 |
# https://github.com/satijalab/seurat/issues/2228 |
|
|
195 |
sc_atac_rna_full_preds_anndata = sc.AnnData( |
|
|
196 |
sc_atac_rna_full_preds, |
|
|
197 |
obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True), |
|
|
198 |
) |
|
|
199 |
sc_atac_rna_full_preds_anndata.var_names = gene_names |
|
|
200 |
logging.info("Writing RNA from ATAC") |
|
|
201 |
|
|
|
202 |
# Seurat also expects the raw attribute to be populated |
|
|
203 |
sc_atac_rna_full_preds_anndata.raw = sc_atac_rna_full_preds_anndata.copy() |
|
|
204 |
sc_atac_rna_full_preds_anndata.write( |
|
|
205 |
os.path.join(outdir, f"{prefix}_atac_rna_adata.h5ad".strip("_")) |
|
|
206 |
) |
|
|
207 |
# sc_atac_rna_full_preds_anndata.write_csvs( |
|
|
208 |
# os.path.join(outdir, f"{prefix}_atac_rna_constituent_csv".strip("_")), |
|
|
209 |
# skip_data=False, |
|
|
210 |
# ) |
|
|
211 |
# sc_atac_rna_full_preds_anndata.to_df().to_csv( |
|
|
212 |
# os.path.join(outdir, f"{prefix}_atac_rna_table.csv".strip("_")) |
|
|
213 |
# ) |
|
|
214 |
|
|
|
215 |
# If there eixsts a ground truth RNA, do RNA plotting |
|
|
216 |
if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None: |
|
|
217 |
logging.info("Plotting RNA from ATAC") |
|
|
218 |
plot_utils.plot_scatter_with_r( |
|
|
219 |
sc_dual_full_dataset.dataset_x.size_norm_counts.X, |
|
|
220 |
sc_atac_rna_full_preds, |
|
|
221 |
one_to_one=True, |
|
|
222 |
logscale=True, |
|
|
223 |
density_heatmap=True, |
|
|
224 |
title=f"{DATASET_NAME} ATAC > RNA".strip(), |
|
|
225 |
fname=os.path.join(outdir, f"{prefix}_atac_rna_log.{ext}".strip("_")), |
|
|
226 |
) |
|
|
227 |
|
|
|
228 |
# Remove objects to free memory |
|
|
229 |
del sc_atac_rna_full_preds |
|
|
230 |
del sc_atac_rna_full_preds_anndata |
|
|
231 |
|
|
|
232 |
|
|
|
233 |
def do_latent_evaluation( |
|
|
234 |
spliced_net, sc_dual_full_dataset, outdir: str, prefix: str = "" |
|
|
235 |
): |
|
|
236 |
""" |
|
|
237 |
Pull out latent space and write to file |
|
|
238 |
""" |
|
|
239 |
logging.info("Inferring latent representations") |
|
|
240 |
encoded_from_rna, encoded_from_atac = spliced_net.get_encoded_layer( |
|
|
241 |
sc_dual_full_dataset |
|
|
242 |
) |
|
|
243 |
|
|
|
244 |
if hasattr(sc_dual_full_dataset.dataset_x, "data_raw"): |
|
|
245 |
encoded_from_rna_adata = sc.AnnData( |
|
|
246 |
encoded_from_rna, |
|
|
247 |
obs=sc_dual_full_dataset.dataset_x.data_raw.obs.copy(deep=True), |
|
|
248 |
) |
|
|
249 |
encoded_from_rna_adata.write( |
|
|
250 |
os.path.join(outdir, f"{prefix}_rna_encoded_adata.h5ad".strip("_")) |
|
|
251 |
) |
|
|
252 |
if hasattr(sc_dual_full_dataset.dataset_y, "data_raw"): |
|
|
253 |
encoded_from_atac_adata = sc.AnnData( |
|
|
254 |
encoded_from_atac, |
|
|
255 |
obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True), |
|
|
256 |
) |
|
|
257 |
encoded_from_atac_adata.write( |
|
|
258 |
os.path.join(outdir, f"{prefix}_atac_encoded_adata.h5ad".strip("_")) |
|
|
259 |
) |
|
|
260 |
|
|
|
261 |
|
|
|
262 |
def infer_reader(fname: str, mode: str = "atac") -> Callable: |
|
|
263 |
"""Given a filename, infer the correct reader to use""" |
|
|
264 |
assert mode in ["atac", "rna"], f"Unrecognized mode: {mode}" |
|
|
265 |
if fname.endswith(".h5"): |
|
|
266 |
if mode == "atac": |
|
|
267 |
return functools.partial(utils.sc_read_10x_h5_ft_type, ft_type="Peaks") |
|
|
268 |
else: |
|
|
269 |
return utils.sc_read_10x_h5_ft_type |
|
|
270 |
elif fname.endswith(".h5ad"): |
|
|
271 |
return ad.read_h5ad |
|
|
272 |
else: |
|
|
273 |
raise ValueError(f"Unrecognized extension: {fname}") |
|
|
274 |
|
|
|
275 |
|
|
|
276 |
def build_parser(): |
|
|
277 |
parser = argparse.ArgumentParser( |
|
|
278 |
usage=__doc__, |
|
|
279 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
|
|
280 |
) |
|
|
281 |
parser.add_argument( |
|
|
282 |
"--checkpoint", |
|
|
283 |
type=str, |
|
|
284 |
nargs="*", |
|
|
285 |
required=False, |
|
|
286 |
default=[ |
|
|
287 |
os.path.join(model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only") |
|
|
288 |
], |
|
|
289 |
help="Checkpoint directory to load model from. If not given, automatically download and use a human pretrained model", |
|
|
290 |
) |
|
|
291 |
parser.add_argument("--prefix", type=str, default="net_", help="Checkpoint prefix") |
|
|
292 |
parser.add_argument("--data", required=True, nargs="*", help="Data files") |
|
|
293 |
parser.add_argument( |
|
|
294 |
"--dataname", default="", help="Name of dataset to include in plot titles" |
|
|
295 |
) |
|
|
296 |
parser.add_argument( |
|
|
297 |
"--outdir", type=str, required=True, help="Output directory for files and plots" |
|
|
298 |
) |
|
|
299 |
parser.add_argument( |
|
|
300 |
"--genes", |
|
|
301 |
type=str, |
|
|
302 |
default="", |
|
|
303 |
help="Genes that the model uses (inferred based on checkpoint dir if not given)", |
|
|
304 |
) |
|
|
305 |
parser.add_argument( |
|
|
306 |
"--bins", |
|
|
307 |
type=str, |
|
|
308 |
default="", |
|
|
309 |
help="ATAC bins that the model uses (inferred based on checkpoint dir if not given)", |
|
|
310 |
) |
|
|
311 |
parser.add_argument( |
|
|
312 |
"--liftHg19toHg38", |
|
|
313 |
action="store_true", |
|
|
314 |
help="Liftover input ATAC bins from hg19 to hg38", |
|
|
315 |
) |
|
|
316 |
parser.add_argument("--device", type=str, default="0", help="Device to use") |
|
|
317 |
parser.add_argument( |
|
|
318 |
"--ext", |
|
|
319 |
type=str, |
|
|
320 |
default="pdf", |
|
|
321 |
choices=["pdf", "png", "jpg"], |
|
|
322 |
help="File format to use for plotting", |
|
|
323 |
) |
|
|
324 |
parser.add_argument( |
|
|
325 |
"--noplot", action="store_true", help="Disable plotting, writing output only" |
|
|
326 |
) |
|
|
327 |
parser.add_argument( |
|
|
328 |
"--transonly", |
|
|
329 |
action="store_true", |
|
|
330 |
help="Disable doing same-modality inference", |
|
|
331 |
) |
|
|
332 |
parser.add_argument( |
|
|
333 |
"--skiprnasource", action="store_true", help="Skip analysis starting from RNA" |
|
|
334 |
) |
|
|
335 |
parser.add_argument( |
|
|
336 |
"--skipatacsource", action="store_true", help="Skip analysis starting from ATAC" |
|
|
337 |
) |
|
|
338 |
parser.add_argument( |
|
|
339 |
"--nofilter", |
|
|
340 |
action="store_true", |
|
|
341 |
help="Whether or not to perform filtering (note that we always discard cells with no expressed genes)", |
|
|
342 |
) |
|
|
343 |
return parser |
|
|
344 |
|
|
|
345 |
|
|
|
346 |
def load_rna_files_for_eval( |
|
|
347 |
data, checkpoint: str, rna_genes_list_fname: str = "", no_filter: bool = False |
|
|
348 |
): |
|
|
349 |
""" """ |
|
|
350 |
if not rna_genes_list_fname: |
|
|
351 |
rna_genes_list_fname = os.path.join(checkpoint, "rna_genes.txt") |
|
|
352 |
assert os.path.isfile( |
|
|
353 |
rna_genes_list_fname |
|
|
354 |
), f"Cannot find RNA genes file: {rna_genes_list_fname}" |
|
|
355 |
rna_genes = utils.read_delimited_file(rna_genes_list_fname) |
|
|
356 |
rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS) |
|
|
357 |
if no_filter: |
|
|
358 |
rna_data_kwargs = { |
|
|
359 |
k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_") |
|
|
360 |
} |
|
|
361 |
# Always discard cells with no expressed genes |
|
|
362 |
rna_data_kwargs["filt_cell_min_genes"] = 1 |
|
|
363 |
rna_data_kwargs["fname"] = data |
|
|
364 |
reader_func = functools.partial( |
|
|
365 |
utils.sc_read_multi_files, |
|
|
366 |
reader=lambda x: sc_data_loaders.repool_genes( |
|
|
367 |
utils.get_ad_reader(x, ft_type="Gene Expression")(x), rna_genes |
|
|
368 |
), |
|
|
369 |
) |
|
|
370 |
rna_data_kwargs["reader"] = reader_func |
|
|
371 |
try: |
|
|
372 |
logging.info(f"Building RNA dataset with parameters: {rna_data_kwargs}") |
|
|
373 |
sc_rna_full_dataset = sc_data_loaders.SingleCellDataset( |
|
|
374 |
mode="skip", |
|
|
375 |
**rna_data_kwargs, |
|
|
376 |
) |
|
|
377 |
assert all( |
|
|
378 |
[x == y for x, y in zip(rna_genes, sc_rna_full_dataset.data_raw.var_names)] |
|
|
379 |
), "Mismatched genes" |
|
|
380 |
_temp = sc_rna_full_dataset[0] # Try that query works |
|
|
381 |
# adata_utils.find_marker_genes(sc_rna_full_dataset.data_raw, n_genes=25) |
|
|
382 |
# marker_genes = adata_utils.flatten_marker_genes( |
|
|
383 |
# sc_rna_full_dataset.data_raw.uns["rank_genes_leiden"] |
|
|
384 |
# ) |
|
|
385 |
marker_genes = [] |
|
|
386 |
# Write out the truth |
|
|
387 |
except (AssertionError, IndexError) as e: |
|
|
388 |
logging.warning(f"Error when reading RNA gene expression data from {data}: {e}") |
|
|
389 |
logging.warning("Ignoring RNA data") |
|
|
390 |
# Update length later |
|
|
391 |
sc_rna_full_dataset = sc_data_loaders.DummyDataset( |
|
|
392 |
shape=len(rna_genes), length=-1 |
|
|
393 |
) |
|
|
394 |
marker_genes = [] |
|
|
395 |
return sc_rna_full_dataset, rna_genes, marker_genes |
|
|
396 |
|
|
|
397 |
|
|
|
398 |
def load_atac_files_for_eval( |
|
|
399 |
data: List[str], |
|
|
400 |
checkpoint: str, |
|
|
401 |
atac_bins_list_fname: str = "", |
|
|
402 |
lift_hg19_to_hg39: bool = False, |
|
|
403 |
predefined_split=None, |
|
|
404 |
): |
|
|
405 |
"""Load the ATAC files for evaluation""" |
|
|
406 |
if not atac_bins_list_fname: |
|
|
407 |
atac_bins_list_fname = os.path.join(checkpoint, "atac_bins.txt") |
|
|
408 |
logging.info(f"Auto-set atac bins fname to {atac_bins_list_fname}") |
|
|
409 |
assert os.path.isfile( |
|
|
410 |
atac_bins_list_fname |
|
|
411 |
), f"Cannot find ATAC bins file: {atac_bins_list_fname}" |
|
|
412 |
atac_bins = utils.read_delimited_file( |
|
|
413 |
atac_bins_list_fname |
|
|
414 |
) # These are the bins we are using (i.e. the bins the model was trained on) |
|
|
415 |
atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS) |
|
|
416 |
atac_data_kwargs["fname"] = data |
|
|
417 |
atac_data_kwargs["cluster_res"] = 0 # Disable clustering |
|
|
418 |
filt_atac_keys = [k for k in atac_data_kwargs.keys() if k.startswith("filt")] |
|
|
419 |
for k in filt_atac_keys: # Reset filtering |
|
|
420 |
atac_data_kwargs[k] = None |
|
|
421 |
atac_data_kwargs["pool_genomic_interval"] = atac_bins |
|
|
422 |
if not lift_hg19_to_hg39: |
|
|
423 |
atac_data_kwargs["reader"] = functools.partial( |
|
|
424 |
utils.sc_read_multi_files, |
|
|
425 |
reader=lambda x: sc_data_loaders.repool_atac_bins( |
|
|
426 |
infer_reader(data[0], mode="atac")(x), |
|
|
427 |
atac_bins, |
|
|
428 |
), |
|
|
429 |
) |
|
|
430 |
else: # Requires liftover |
|
|
431 |
# Read, liftover, then repool |
|
|
432 |
atac_data_kwargs["reader"] = functools.partial( |
|
|
433 |
utils.sc_read_multi_files, |
|
|
434 |
reader=lambda x: sc_data_loaders.repool_atac_bins( |
|
|
435 |
sc_data_loaders.liftover_atac_adata( |
|
|
436 |
# utils.sc_read_10x_h5_ft_type(x, "Peaks") |
|
|
437 |
infer_reader(data[0], mode="atac")(x) |
|
|
438 |
), |
|
|
439 |
atac_bins, |
|
|
440 |
), |
|
|
441 |
) |
|
|
442 |
|
|
|
443 |
try: |
|
|
444 |
sc_atac_full_dataset = sc_data_loaders.SingleCellDataset( |
|
|
445 |
mode="skip", |
|
|
446 |
predefined_split=predefined_split if predefined_split else None, |
|
|
447 |
**atac_data_kwargs, |
|
|
448 |
) |
|
|
449 |
_temp = sc_atac_full_dataset[0] # Try that query works |
|
|
450 |
assert all( |
|
|
451 |
[x == y for x, y in zip(atac_bins, sc_atac_full_dataset.data_raw.var_names)] |
|
|
452 |
) |
|
|
453 |
except AssertionError as err: |
|
|
454 |
logging.warning(f"Error when reading ATAC data from {data}: {err}") |
|
|
455 |
logging.warning("Ignoring ATAC data, returning dummy dataset instead") |
|
|
456 |
sc_atac_full_dataset = sc_data_loaders.DummyDataset( |
|
|
457 |
shape=len(atac_bins), length=-1 |
|
|
458 |
) |
|
|
459 |
return sc_atac_full_dataset, atac_bins |
|
|
460 |
|
|
|
461 |
|
|
|
462 |
def main(): |
|
|
463 |
parser = build_parser() |
|
|
464 |
args = parser.parse_args() |
|
|
465 |
logging.info(f"Evaluating: {' '.join(args.data)}") |
|
|
466 |
|
|
|
467 |
global DATASET_NAME |
|
|
468 |
DATASET_NAME = args.dataname |
|
|
469 |
|
|
|
470 |
# Create output directory |
|
|
471 |
if not os.path.isdir(args.outdir): |
|
|
472 |
os.makedirs(args.outdir) |
|
|
473 |
|
|
|
474 |
# Set up logging |
|
|
475 |
logger = logging.getLogger() |
|
|
476 |
fh = logging.FileHandler(os.path.join(args.outdir, "logging.log"), "w") |
|
|
477 |
fh.setLevel(logging.INFO) |
|
|
478 |
logger.addHandler(fh) |
|
|
479 |
|
|
|
480 |
if args.checkpoint[0] == os.path.join( |
|
|
481 |
model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only" |
|
|
482 |
): |
|
|
483 |
_ = model_utils.load_model() # Downloads if not downloaded |
|
|
484 |
(sc_rna_full_dataset, rna_genes, marker_genes,) = load_rna_files_for_eval( |
|
|
485 |
args.data, args.checkpoint[0], args.genes, no_filter=args.nofilter |
|
|
486 |
) |
|
|
487 |
|
|
|
488 |
if hasattr(sc_rna_full_dataset, "size_norm_counts"): |
|
|
489 |
logging.info("Writing truth RNA size normalized counts") |
|
|
490 |
sc_rna_full_dataset.size_norm_counts.write_h5ad( |
|
|
491 |
os.path.join(args.outdir, "truth_rna.h5ad") |
|
|
492 |
) |
|
|
493 |
|
|
|
494 |
sc_atac_full_dataset, atac_bins = load_atac_files_for_eval( |
|
|
495 |
args.data, |
|
|
496 |
args.checkpoint[0], |
|
|
497 |
args.bins, |
|
|
498 |
args.liftHg19toHg38, |
|
|
499 |
sc_rna_full_dataset if hasattr(sc_rna_full_dataset, "data_raw") else None, |
|
|
500 |
) |
|
|
501 |
# Write out the truth |
|
|
502 |
if hasattr(sc_atac_full_dataset, "data_raw"): |
|
|
503 |
logging.info("Writing truth ATAC binary counts") |
|
|
504 |
sc_atac_full_dataset.data_raw.write_h5ad( |
|
|
505 |
os.path.join(args.outdir, "truth_atac.h5ad") |
|
|
506 |
) |
|
|
507 |
|
|
|
508 |
if isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset) and isinstance( |
|
|
509 |
sc_atac_full_dataset, sc_data_loaders.DummyDataset |
|
|
510 |
): |
|
|
511 |
raise ValueError("Cannot proceed with two dummy datasets for both RNA and ATAC") |
|
|
512 |
# Update the RNA counts if we do not actually have RNA data |
|
|
513 |
if isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset) and not isinstance( |
|
|
514 |
sc_atac_full_dataset, sc_data_loaders.DummyDataset |
|
|
515 |
): |
|
|
516 |
sc_rna_full_dataset.length = len(sc_atac_full_dataset) |
|
|
517 |
elif isinstance( |
|
|
518 |
sc_atac_full_dataset, sc_data_loaders.DummyDataset |
|
|
519 |
) and not isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset): |
|
|
520 |
sc_atac_full_dataset.length = len(sc_rna_full_dataset) |
|
|
521 |
|
|
|
522 |
# Build the dual combined dataset |
|
|
523 |
sc_dual_full_dataset = sc_data_loaders.PairedDataset( |
|
|
524 |
sc_rna_full_dataset, |
|
|
525 |
sc_atac_full_dataset, |
|
|
526 |
flat_mode=True, |
|
|
527 |
) |
|
|
528 |
|
|
|
529 |
# Write some basic outputs related to variable and obs names |
|
|
530 |
with open(os.path.join(args.outdir, "rna_genes.txt"), "w") as sink: |
|
|
531 |
sink.write("\n".join(rna_genes) + "\n") |
|
|
532 |
with open(os.path.join(args.outdir, "atac_bins.txt"), "w") as sink: |
|
|
533 |
sink.write("\n".join(atac_bins) + "\n") |
|
|
534 |
with open(os.path.join(args.outdir, "obs_names.txt"), "w") as sink: |
|
|
535 |
sink.write("\n".join(sc_dual_full_dataset.obs_names)) |
|
|
536 |
|
|
|
537 |
for i, ckpt in enumerate(args.checkpoint): |
|
|
538 |
# Dynamically determine the model we are looking at based on name |
|
|
539 |
checkpoint_basename = os.path.basename(ckpt) |
|
|
540 |
if checkpoint_basename.startswith("naive"): |
|
|
541 |
logging.info(f"Inferred model to be naive") |
|
|
542 |
model_class = autoencoders.NaiveSplicedAutoEncoder |
|
|
543 |
else: |
|
|
544 |
logging.info(f"Inferred model to be normal (non-naive)") |
|
|
545 |
model_class = autoencoders.AssymSplicedAutoEncoder |
|
|
546 |
|
|
|
547 |
prefix = "" if len(args.checkpoint) == 1 else f"model_{checkpoint_basename}" |
|
|
548 |
spliced_net = model_utils.load_model( |
|
|
549 |
ckpt, |
|
|
550 |
prefix=args.prefix, |
|
|
551 |
device=args.device, |
|
|
552 |
) |
|
|
553 |
|
|
|
554 |
do_latent_evaluation( |
|
|
555 |
spliced_net=spliced_net, |
|
|
556 |
sc_dual_full_dataset=sc_dual_full_dataset, |
|
|
557 |
outdir=args.outdir, |
|
|
558 |
prefix=prefix, |
|
|
559 |
) |
|
|
560 |
|
|
|
561 |
if ( |
|
|
562 |
isinstance(sc_rna_full_dataset, sc_data_loaders.SingleCellDataset) |
|
|
563 |
and not args.skiprnasource |
|
|
564 |
): |
|
|
565 |
if not args.transonly: |
|
|
566 |
do_evaluation_rna_from_rna( |
|
|
567 |
spliced_net, |
|
|
568 |
sc_dual_full_dataset, |
|
|
569 |
rna_genes, |
|
|
570 |
atac_bins, |
|
|
571 |
args.outdir, |
|
|
572 |
None if args.noplot else args.ext, |
|
|
573 |
marker_genes, |
|
|
574 |
prefix=prefix, |
|
|
575 |
) |
|
|
576 |
do_evaluation_atac_from_rna( |
|
|
577 |
spliced_net, |
|
|
578 |
sc_dual_full_dataset, |
|
|
579 |
rna_genes, |
|
|
580 |
atac_bins, |
|
|
581 |
args.outdir, |
|
|
582 |
None if args.noplot else args.ext, |
|
|
583 |
marker_genes, |
|
|
584 |
prefix=prefix, |
|
|
585 |
) |
|
|
586 |
if ( |
|
|
587 |
isinstance(sc_atac_full_dataset, sc_data_loaders.SingleCellDataset) |
|
|
588 |
and not args.skipatacsource |
|
|
589 |
): |
|
|
590 |
do_evaluation_rna_from_atac( |
|
|
591 |
spliced_net, |
|
|
592 |
sc_dual_full_dataset, |
|
|
593 |
rna_genes, |
|
|
594 |
atac_bins, |
|
|
595 |
args.outdir, |
|
|
596 |
None if args.noplot else args.ext, |
|
|
597 |
marker_genes, |
|
|
598 |
prefix=prefix, |
|
|
599 |
) |
|
|
600 |
if not args.transonly: |
|
|
601 |
do_evaluation_atac_from_atac( |
|
|
602 |
spliced_net, |
|
|
603 |
sc_dual_full_dataset, |
|
|
604 |
rna_genes, |
|
|
605 |
atac_bins, |
|
|
606 |
args.outdir, |
|
|
607 |
None if args.noplot else args.ext, |
|
|
608 |
marker_genes, |
|
|
609 |
prefix=prefix, |
|
|
610 |
) |
|
|
611 |
del spliced_net |
|
|
612 |
|
|
|
613 |
|
|
|
614 |
if __name__ == "__main__": |
|
|
615 |
main() |