Diff of /bin/predict_model.py [000000] .. [d01132]

Switch to unified view

a b/bin/predict_model.py
1
"""
2
Code for evaluating a model's ability to generalize to cells that it wasn't trained on.
3
Can only be used to evaluate within a species.
4
Generates raw predictions of data modality transfer, and optionally, plots.
5
"""
6
7
import os
8
import sys
9
from typing import *
10
import functools
11
import logging
12
import argparse
13
import copy
14
15
import scipy
16
17
import anndata as ad
18
import scanpy as sc
19
20
import torch
21
import skorch
22
23
# The "babel" source directory is a sibling of the directory that contains this
# script; it is appended to sys.path so the project modules can be imported.
_THIS_FILE = os.path.abspath(__file__)
SRC_DIR = os.path.join(os.path.dirname(os.path.dirname(_THIS_FILE)), "babel")
del _THIS_FILE  # Do not leave the temporary name in the module namespace
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
29
import sc_data_loaders
30
import loss_functions
31
import model_utils
32
import plot_utils
33
import adata_utils
34
import utils
35
from models import autoencoders
36
37
# The "data" directory lives next to SRC_DIR (same parent directory).
DATA_DIR = os.path.join(os.path.dirname(SRC_DIR), "data")
assert os.path.isdir(DATA_DIR)

# Emit INFO-level (and above) log records by default.
logging.basicConfig(level=logging.INFO)

# Dataset label used in plot titles; main() overwrites it from --dataname.
DATASET_NAME = ""
43
44
45
def do_evaluation_rna_from_rna(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: List[str],
    atac_names: List[str],
    outdir: str,
    ext: Optional[str],
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Infer RNA from RNA, write the predictions, and optionally plot them.

    Args:
        spliced_net: model providing ``translate_1_to_1``.
        sc_dual_full_dataset: paired dataset; ``dataset_x.data_raw.obs``
            supplies the per-cell annotations of the written AnnData.
        gene_names: names assigned to the ``var_names`` of the predictions.
        atac_names: unused here; accepted for signature parity with the other
            ``do_evaluation_*`` functions.
        outdir: directory that receives the h5ad file and the plot.
        ext: plot file extension, or None to skip plotting.
        marker_genes: unused here.
        prefix: optional prefix for the output file names.
    """
    logging.info("Inferring RNA from RNA...")
    sc_rna_full_preds = spliced_net.translate_1_to_1(sc_dual_full_dataset)
    sc_rna_full_preds_anndata = sc.AnnData(
        sc_rna_full_preds,
        # Deep-copy obs (as the ATAC-source functions do) so that nothing done
        # to the predictions' annotations can leak into the source dataset.
        obs=sc_dual_full_dataset.dataset_x.data_raw.obs.copy(deep=True),
    )
    sc_rna_full_preds_anndata.var_names = gene_names

    logging.info("Writing RNA from RNA")
    sc_rna_full_preds_anndata.write(
        os.path.join(outdir, f"{prefix}_rna_rna_adata.h5ad".strip("_"))
    )

    # Plot only when ground-truth normalized counts exist and plotting is enabled
    if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None:
        logging.info("Plotting RNA from RNA")
        plot_utils.plot_scatter_with_r(
            sc_dual_full_dataset.dataset_x.size_norm_counts.X,
            sc_rna_full_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title=f"{DATASET_NAME} RNA > RNA".strip(),
            fname=os.path.join(outdir, f"{prefix}_rna_rna_log.{ext}".strip("_")),
        )
84
85
def do_evaluation_atac_from_rna(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: List[str],
    atac_names: List[str],
    outdir: str,
    ext: Optional[str],
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Infer ATAC from RNA, write the predictions, and optionally plot an AUROC.

    Args:
        spliced_net: model providing ``translate_1_to_2``.
        sc_dual_full_dataset: paired dataset; ``dataset_x.data_raw.obs``
            supplies the per-cell annotations of the written AnnData, and
            ``dataset_y.data_raw`` (when present) is the truth used for plotting.
        gene_names: unused here; accepted for signature parity with the other
            ``do_evaluation_*`` functions.
        atac_names: names assigned to the ``var_names`` of the predictions.
        outdir: directory that receives the h5ad file and the plot.
        ext: plot file extension, or None to skip plotting.
        marker_genes: unused here.
        prefix: optional prefix for the output file names.
    """
    logging.info("Inferring ATAC from RNA")
    sc_rna_atac_full_preds = spliced_net.translate_1_to_2(sc_dual_full_dataset)
    sc_rna_atac_full_preds_anndata = sc.AnnData(
        # The predictions are stored as a CSR sparse matrix
        scipy.sparse.csr_matrix(sc_rna_atac_full_preds),
        # Deep-copy obs (as the ATAC-source functions do) so that nothing done
        # to the predictions' annotations can leak into the source dataset.
        obs=sc_dual_full_dataset.dataset_x.data_raw.obs.copy(deep=True),
    )
    sc_rna_atac_full_preds_anndata.var_names = atac_names

    logging.info("Writing ATAC from RNA")
    sc_rna_atac_full_preds_anndata.write(
        os.path.join(outdir, f"{prefix}_rna_atac_adata.h5ad".strip("_"))
    )

    # Plot only when ground-truth ATAC exists and plotting is enabled
    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None:
        logging.info("Plotting ATAC from RNA")
        plot_utils.plot_auroc(
            utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
            utils.ensure_arr(sc_rna_atac_full_preds).flatten(),
            title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(),
            fname=os.path.join(outdir, f"{prefix}_rna_atac_auroc.{ext}".strip("_")),
        )
        # NOTE: an equivalent plot_utils.plot_auprc call used to follow here but
        # was commented out in the original source; it is left disabled.
122
123
124
def do_evaluation_atac_from_atac(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: List[str],
    atac_names: List[str],
    outdir: str,
    ext: Optional[str],
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Infer ATAC from ATAC, write the predictions, and optionally plot an AUROC.

    Args:
        spliced_net: model providing ``translate_2_to_2``.
        sc_dual_full_dataset: paired dataset; ``dataset_y.data_raw`` supplies
            the per-cell annotations and the truth used for plotting.
        gene_names: unused here; accepted for signature parity with the other
            ``do_evaluation_*`` functions.
        atac_names: names assigned to the ``var_names`` of the predictions.
        outdir: directory that receives the h5ad file and the plot.
        ext: plot file extension, or None to skip plotting.
        marker_genes: unused here.
        prefix: optional prefix for the output file names.
    """
    logging.info("Inferring ATAC from ATAC")
    sc_atac_full_preds = spliced_net.translate_2_to_2(sc_dual_full_dataset)
    sc_atac_full_preds_anndata = sc.AnnData(
        sc_atac_full_preds,
        obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
    )
    sc_atac_full_preds_anndata.var_names = atac_names

    # NOTE: marker-bin inference on the predictions (preprocess_anndata /
    # find_marker_genes / flatten_marker_genes, written to
    # "<prefix>_atac_atac_marker_bins.txt") was commented out in the original
    # source and remains disabled.

    logging.info("Writing ATAC from ATAC")
    sc_atac_full_preds_anndata.write(
        os.path.join(outdir, f"{prefix}_atac_atac_adata.h5ad".strip("_"))
    )

    # Plot only when ground-truth ATAC exists and plotting is enabled
    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None:
        logging.info("Plotting ATAC from ATAC")
        plot_utils.plot_auroc(
            utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
            utils.ensure_arr(sc_atac_full_preds).flatten(),
            title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
            fname=os.path.join(outdir, f"{prefix}_atac_atac_auroc.{ext}".strip("_")),
        )
        # NOTE: an equivalent plot_utils.plot_auprc call used to follow here but
        # was commented out in the original source; it is left disabled.
178
179
180
def do_evaluation_rna_from_atac(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: List[str],
    atac_names: List[str],
    outdir: str,
    ext: Optional[str],
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Infer RNA from ATAC, write the predictions, and optionally plot them.

    Args:
        spliced_net: model providing ``translate_2_to_1``.
        sc_dual_full_dataset: paired dataset; ``dataset_y.data_raw.obs``
            supplies the per-cell annotations, and
            ``dataset_x.size_norm_counts`` (when present) is the RNA truth used
            for plotting.
        gene_names: names assigned to the ``var_names`` of the predictions.
        atac_names: unused here; accepted for signature parity with the other
            ``do_evaluation_*`` functions.
        outdir: directory that receives the h5ad file and the plot.
        ext: plot file extension, or None to skip plotting.
        marker_genes: unused here.
        prefix: optional prefix for the output file names.
    """
    logging.info("Inferring RNA from ATAC")
    sc_atac_rna_full_preds = spliced_net.translate_2_to_1(sc_dual_full_dataset)
    # Original author's note: Seurat expects everything to be sparse
    # https://github.com/satijalab/seurat/issues/2228
    # NOTE(review): the predictions are passed to AnnData as returned by the
    # model, with no sparse conversion here - confirm whether one is wanted.
    sc_atac_rna_full_preds_anndata = sc.AnnData(
        sc_atac_rna_full_preds,
        obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
    )
    sc_atac_rna_full_preds_anndata.var_names = gene_names

    # Seurat also expects the raw attribute to be populated
    sc_atac_rna_full_preds_anndata.raw = sc_atac_rna_full_preds_anndata.copy()

    logging.info("Writing RNA from ATAC")
    sc_atac_rna_full_preds_anndata.write(
        os.path.join(outdir, f"{prefix}_atac_rna_adata.h5ad".strip("_"))
    )
    # NOTE: CSV exports of the predictions (write_csvs / to_df().to_csv) were
    # commented out in the original source and remain disabled.

    # If there exists a ground truth RNA, do RNA plotting
    if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None:
        logging.info("Plotting RNA from ATAC")
        plot_utils.plot_scatter_with_r(
            sc_dual_full_dataset.dataset_x.size_norm_counts.X,
            sc_atac_rna_full_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title=f"{DATASET_NAME} ATAC > RNA".strip(),
            fname=os.path.join(outdir, f"{prefix}_atac_rna_log.{ext}".strip("_")),
        )
231
232
233
def do_latent_evaluation(
    spliced_net, sc_dual_full_dataset, outdir: str, prefix: str = ""
):
    """
    Extract the encoded (latent) representation of each modality and write it
    to ``outdir`` as "<prefix>_rna_encoded_adata.h5ad" and
    "<prefix>_atac_encoded_adata.h5ad" (leading/trailing underscores stripped).

    A modality is written only if its dataset has a ``data_raw`` attribute,
    whose ``obs`` is deep-copied into the output.
    """

    def _write_encoded(dataset, encoded, modality):
        # Nothing to annotate the cells with if the dataset has no raw data
        if not hasattr(dataset, "data_raw"):
            return
        out_name = f"{prefix}_{modality}_encoded_adata.h5ad".strip("_")
        encoded_adata = sc.AnnData(encoded, obs=dataset.data_raw.obs.copy(deep=True))
        encoded_adata.write(os.path.join(outdir, out_name))

    logging.info("Inferring latent representations")
    latent_pair = spliced_net.get_encoded_layer(sc_dual_full_dataset)
    encoded_from_rna, encoded_from_atac = latent_pair

    _write_encoded(sc_dual_full_dataset.dataset_x, encoded_from_rna, "rna")
    _write_encoded(sc_dual_full_dataset.dataset_y, encoded_from_atac, "atac")
260
261
262
def infer_reader(fname: str, mode: str = "atac") -> Callable:
    """
    Pick the reader callable matching the extension of ``fname``.

    ".h5ad" files use ``ad.read_h5ad``. ".h5" files use
    ``utils.sc_read_10x_h5_ft_type``, pre-bound to ``ft_type="Peaks"`` when
    ``mode`` is "atac". Any other extension raises ValueError; a ``mode`` other
    than "atac" or "rna" fails an assertion.
    """
    assert mode in ["atac", "rna"], f"Unrecognized mode: {mode}"

    if fname.endswith(".h5ad"):
        return ad.read_h5ad
    if not fname.endswith(".h5"):
        raise ValueError(f"Unrecognized extension: {fname}")

    # 10x .h5 input: ATAC mode restricts the reader to the "Peaks" feature type
    tenx_reader = utils.sc_read_10x_h5_ft_type
    if mode == "atac":
        return functools.partial(tenx_reader, ft_type="Peaks")
    return tenx_reader
274
275
276
def build_parser():
    """
    Construct the command-line parser for this script.

    The module docstring is used as the usage text and argument defaults are
    shown in the help output.
    """
    default_checkpoint = os.path.join(
        model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only"
    )

    argp = argparse.ArgumentParser(
        usage=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model and input/output locations
    argp.add_argument(
        "--checkpoint",
        type=str,
        nargs="*",
        required=False,
        default=[default_checkpoint],
        help="Checkpoint directory to load model from. If not given, automatically download and use a human pretrained model",
    )
    argp.add_argument("--prefix", type=str, default="net_", help="Checkpoint prefix")
    argp.add_argument("--data", required=True, nargs="*", help="Data files")
    argp.add_argument(
        "--dataname", default="", help="Name of dataset to include in plot titles"
    )
    argp.add_argument(
        "--outdir", type=str, required=True, help="Output directory for files and plots"
    )

    # Feature lists the model was built on
    argp.add_argument(
        "--genes",
        type=str,
        default="",
        help="Genes that the model uses (inferred based on checkpoint dir if not given)",
    )
    argp.add_argument(
        "--bins",
        type=str,
        default="",
        help="ATAC bins that the model uses (inferred based on checkpoint dir if not given)",
    )
    argp.add_argument(
        "--liftHg19toHg38",
        action="store_true",
        help="Liftover input ATAC bins from hg19 to hg38",
    )

    # Runtime and plotting options
    argp.add_argument("--device", type=str, default="0", help="Device to use")
    argp.add_argument(
        "--ext",
        type=str,
        default="pdf",
        choices=["pdf", "png", "jpg"],
        help="File format to use for plotting",
    )

    # Remaining boolean switches, registered in the order they appear in --help
    switches = (
        ("--noplot", "Disable plotting, writing output only"),
        ("--transonly", "Disable doing same-modality inference"),
        ("--skiprnasource", "Skip analysis starting from RNA"),
        ("--skipatacsource", "Skip analysis starting from ATAC"),
        (
            "--nofilter",
            "Whether or not to perform filtering (note that we always discard cells with no expressed genes)",
        ),
    )
    for flag, help_text in switches:
        argp.add_argument(flag, action="store_true", help=help_text)

    return argp
344
345
346
def load_rna_files_for_eval(
    data, checkpoint: str, rna_genes_list_fname: str = "", no_filter: bool = False
):
    """
    Load the RNA data for evaluation, repooled onto the model's gene list.

    The gene list is read from ``rna_genes_list_fname``, defaulting to
    "rna_genes.txt" inside ``checkpoint``. With ``no_filter`` every "filt_*"
    dataset option is dropped except a minimum of one expressed gene per cell.

    If building the dataset fails an assertion (including the gene-order check)
    or the first-item probe raises IndexError, a warning is logged and a
    ``DummyDataset`` with placeholder length -1 is returned instead.

    Returns a tuple ``(dataset, rna_genes, marker_genes)``; ``marker_genes`` is
    always an empty list.
    """
    genes_path = rna_genes_list_fname or os.path.join(checkpoint, "rna_genes.txt")
    assert os.path.isfile(genes_path), f"Cannot find RNA genes file: {genes_path}"
    rna_genes = utils.read_delimited_file(genes_path)

    dataset_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
    if no_filter:
        dataset_kwargs = {
            key: value
            for key, value in dataset_kwargs.items()
            if not key.startswith("filt_")
        }
        # Always discard cells with no expressed genes
        dataset_kwargs["filt_cell_min_genes"] = 1
    dataset_kwargs["fname"] = data
    # Each file is read with the reader matching it, then repooled to rna_genes
    dataset_kwargs["reader"] = functools.partial(
        utils.sc_read_multi_files,
        reader=lambda x: sc_data_loaders.repool_genes(
            utils.get_ad_reader(x, ft_type="Gene Expression")(x), rna_genes
        ),
    )

    # NOTE: marker gene discovery (adata_utils.find_marker_genes) was commented
    # out in the original source, so no marker genes are reported.
    marker_genes = []
    try:
        logging.info(f"Building RNA dataset with parameters: {dataset_kwargs}")
        sc_rna_full_dataset = sc_data_loaders.SingleCellDataset(
            mode="skip",
            **dataset_kwargs,
        )
        genes_in_order = all(
            expected == found
            for expected, found in zip(
                rna_genes, sc_rna_full_dataset.data_raw.var_names
            )
        )
        assert genes_in_order, "Mismatched genes"
        sc_rna_full_dataset[0]  # Try that query works
    except (AssertionError, IndexError) as e:
        logging.warning(f"Error when reading RNA gene expression data from {data}: {e}")
        logging.warning("Ignoring RNA data")
        # Update length later
        sc_rna_full_dataset = sc_data_loaders.DummyDataset(
            shape=len(rna_genes), length=-1
        )
    return sc_rna_full_dataset, rna_genes, marker_genes
396
397
398
def load_atac_files_for_eval(
    data: List[str],
    checkpoint: str,
    atac_bins_list_fname: str = "",
    lift_hg19_to_hg39: bool = False,
    predefined_split=None,
):
    """
    Load the ATAC files for evaluation, repooled onto the model's bins.

    Args:
        data: input file names; the reader is chosen from the extension of
            ``data[0]`` and applied to every file.
        checkpoint: checkpoint directory, used to locate "atac_bins.txt" when
            ``atac_bins_list_fname`` is empty.
        atac_bins_list_fname: file listing the ATAC bins the model uses.
        lift_hg19_to_hg39: if set, each file is passed through
            ``sc_data_loaders.liftover_atac_adata`` before repooling. (The
            parameter name is kept as-is for backward compatibility; the
            command-line flag feeding it is ``--liftHg19toHg38``.)
        predefined_split: forwarded to ``SingleCellDataset`` when truthy,
            otherwise None is passed.

    Returns:
        ``(dataset, atac_bins)``. If the dataset cannot be built (failed
        assertion, mismatched bins, or the first-item probe raising
        IndexError), a warning is logged and a ``DummyDataset`` with
        placeholder length -1 is returned in its place.
    """
    if not atac_bins_list_fname:
        atac_bins_list_fname = os.path.join(checkpoint, "atac_bins.txt")
        logging.info(f"Auto-set atac bins fname to {atac_bins_list_fname}")
    assert os.path.isfile(
        atac_bins_list_fname
    ), f"Cannot find ATAC bins file: {atac_bins_list_fname}"
    # These are the bins we are using (i.e. the bins the model was trained on)
    atac_bins = utils.read_delimited_file(atac_bins_list_fname)

    atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
    atac_data_kwargs["fname"] = data
    atac_data_kwargs["cluster_res"] = 0  # Disable clustering
    for key in [k for k in atac_data_kwargs if k.startswith("filt")]:
        atac_data_kwargs[key] = None  # Reset filtering
    atac_data_kwargs["pool_genomic_interval"] = atac_bins

    def _read_and_repool(fname):
        """Read one file, optionally lift it over, then repool to atac_bins."""
        adata = infer_reader(data[0], mode="atac")(fname)
        if lift_hg19_to_hg39:
            adata = sc_data_loaders.liftover_atac_adata(adata)
        return sc_data_loaders.repool_atac_bins(adata, atac_bins)

    atac_data_kwargs["reader"] = functools.partial(
        utils.sc_read_multi_files,
        reader=_read_and_repool,
    )

    try:
        sc_atac_full_dataset = sc_data_loaders.SingleCellDataset(
            mode="skip",
            predefined_split=predefined_split if predefined_split else None,
            **atac_data_kwargs,
        )
        sc_atac_full_dataset[0]  # Try that query works
        assert all(
            x == y for x, y in zip(atac_bins, sc_atac_full_dataset.data_raw.var_names)
        ), "Mismatched ATAC bins"
    except (AssertionError, IndexError) as err:
        # IndexError is handled too, matching the RNA loader: the probe above
        # raises it when the dataset cannot serve its first item.
        logging.warning(f"Error when reading ATAC data from {data}: {err}")
        logging.warning("Ignoring ATAC data, returning dummy dataset instead")
        # Placeholder length; the caller sets the real length afterwards
        sc_atac_full_dataset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=-1
        )
    return sc_atac_full_dataset, atac_bins
460
461
462
def main():
    """
    Command-line entry point.

    Parses the arguments, loads the RNA and ATAC inputs (substituting a dummy
    dataset for a modality that cannot be read), writes the truth data and the
    gene/bin/cell name lists to the output directory, then for every checkpoint
    loads the model and writes its latent representations and the requested
    translations (plus plots unless --noplot is given).

    Raises:
        ValueError: if neither RNA nor ATAC data could be loaded.
    """
    global DATASET_NAME

    parser = build_parser()
    args = parser.parse_args()
    # "--checkpoint" takes zero or more values; fail cleanly on zero rather
    # than with an IndexError on args.checkpoint[0] below.
    if not args.checkpoint:
        parser.error("--checkpoint requires at least one checkpoint directory")
    logging.info(f"Evaluating: {' '.join(args.data)}")

    DATASET_NAME = args.dataname

    # Create output directory
    os.makedirs(args.outdir, exist_ok=True)

    # Set up logging: mirror log records into a file in the output directory
    logger = logging.getLogger()
    fh = logging.FileHandler(os.path.join(args.outdir, "logging.log"), "w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    default_checkpoint = os.path.join(
        model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only"
    )
    if args.checkpoint[0] == default_checkpoint:
        # Original author's note: downloads if not downloaded
        model_utils.load_model()

    sc_rna_full_dataset, rna_genes, marker_genes = load_rna_files_for_eval(
        args.data, args.checkpoint[0], args.genes, no_filter=args.nofilter
    )
    if hasattr(sc_rna_full_dataset, "size_norm_counts"):
        logging.info("Writing truth RNA size normalized counts")
        sc_rna_full_dataset.size_norm_counts.write_h5ad(
            os.path.join(args.outdir, "truth_rna.h5ad")
        )

    sc_atac_full_dataset, atac_bins = load_atac_files_for_eval(
        args.data,
        args.checkpoint[0],
        args.bins,
        args.liftHg19toHg38,
        sc_rna_full_dataset if hasattr(sc_rna_full_dataset, "data_raw") else None,
    )
    # Write out the truth
    if hasattr(sc_atac_full_dataset, "data_raw"):
        logging.info("Writing truth ATAC binary counts")
        sc_atac_full_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, "truth_atac.h5ad")
        )

    rna_is_dummy = isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset)
    atac_is_dummy = isinstance(sc_atac_full_dataset, sc_data_loaders.DummyDataset)
    if rna_is_dummy and atac_is_dummy:
        raise ValueError("Cannot proceed with two dummy datasets for both RNA and ATAC")
    # A dummy dataset is created with a placeholder length; give it the length
    # of the real dataset it will be paired with.
    if rna_is_dummy:
        sc_rna_full_dataset.length = len(sc_atac_full_dataset)
    elif atac_is_dummy:
        sc_atac_full_dataset.length = len(sc_rna_full_dataset)

    # Build the dual combined dataset
    sc_dual_full_dataset = sc_data_loaders.PairedDataset(
        sc_rna_full_dataset,
        sc_atac_full_dataset,
        flat_mode=True,
    )

    # Write some basic outputs related to variable and obs names
    with open(os.path.join(args.outdir, "rna_genes.txt"), "w") as sink:
        sink.write("\n".join(rna_genes) + "\n")
    with open(os.path.join(args.outdir, "atac_bins.txt"), "w") as sink:
        sink.write("\n".join(atac_bins) + "\n")
    with open(os.path.join(args.outdir, "obs_names.txt"), "w") as sink:
        sink.write("\n".join(sc_dual_full_dataset.obs_names))

    # Shared by every evaluation call below
    plot_ext = None if args.noplot else args.ext
    run_rna_source = (
        isinstance(sc_rna_full_dataset, sc_data_loaders.SingleCellDataset)
        and not args.skiprnasource
    )
    run_atac_source = (
        isinstance(sc_atac_full_dataset, sc_data_loaders.SingleCellDataset)
        and not args.skipatacsource
    )

    for ckpt in args.checkpoint:
        checkpoint_basename = os.path.basename(ckpt)
        # Informational only: load_model below is not given a model class
        if checkpoint_basename.startswith("naive"):
            logging.info("Inferred model to be naive")
        else:
            logging.info("Inferred model to be normal (non-naive)")

        # Outputs are only prefixed when several checkpoints share one outdir
        prefix = "" if len(args.checkpoint) == 1 else f"model_{checkpoint_basename}"
        spliced_net = model_utils.load_model(
            ckpt,
            prefix=args.prefix,
            device=args.device,
        )

        do_latent_evaluation(
            spliced_net=spliced_net,
            sc_dual_full_dataset=sc_dual_full_dataset,
            outdir=args.outdir,
            prefix=prefix,
        )

        eval_args = (
            spliced_net,
            sc_dual_full_dataset,
            rna_genes,
            atac_bins,
            args.outdir,
            plot_ext,
            marker_genes,
        )
        if run_rna_source:
            if not args.transonly:
                do_evaluation_rna_from_rna(*eval_args, prefix=prefix)
            do_evaluation_atac_from_rna(*eval_args, prefix=prefix)
        if run_atac_source:
            do_evaluation_rna_from_atac(*eval_args, prefix=prefix)
            if not args.transonly:
                do_evaluation_atac_from_atac(*eval_args, prefix=prefix)

        # Drop the references to this model before the next one is loaded
        del eval_args
        del spliced_net
612
613
614
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script
    main()