Diff of /bin/predict_model.py [000000] .. [d01132]

Switch to side-by-side view

--- a
+++ b/bin/predict_model.py
@@ -0,0 +1,615 @@
+"""
+Code for evaluating a model's ability to generalize to cells that it wasn't trained on.
+Can only be used to evalute within a species.
+Generates raw predictions of data modality transfer, and optionally, plots.
+"""
+
+import os
+import sys
+from typing import *
+import functools
+import logging
+import argparse
+import copy
+
+import scipy
+
+import anndata as ad
+import scanpy as sc
+
+import torch
+import skorch
+
+SRC_DIR = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+    "babel",
+)
+assert os.path.isdir(SRC_DIR)
+sys.path.append(SRC_DIR)
+import sc_data_loaders
+import loss_functions
+import model_utils
+import plot_utils
+import adata_utils
+import utils
+from models import autoencoders
+
+DATA_DIR = os.path.join(os.path.dirname(SRC_DIR), "data")
+assert os.path.isdir(DATA_DIR)
+
+logging.basicConfig(level=logging.INFO)
+
+DATASET_NAME = ""
+
+
+def do_evaluation_rna_from_rna(
+    spliced_net,
+    sc_dual_full_dataset,
+    gene_names: str,
+    atac_names: str,
+    outdir: str,
+    ext: str,
+    marker_genes: List[str],
+    prefix: str = "",
+):
+    """
+    Evaluate the given network on the dataset
+    """
+    # Do inference and plotting
+    ### RNA > RNA
+    logging.info("Inferring RNA from RNA...")
+    sc_rna_full_preds = spliced_net.translate_1_to_1(sc_dual_full_dataset)
+    sc_rna_full_preds_anndata = sc.AnnData(
+        sc_rna_full_preds,
+        obs=sc_dual_full_dataset.dataset_x.data_raw.obs,
+    )
+    sc_rna_full_preds_anndata.var_names = gene_names
+
+    logging.info("Writing RNA from RNA")
+    sc_rna_full_preds_anndata.write(
+        os.path.join(outdir, f"{prefix}_rna_rna_adata.h5ad".strip("_"))
+    )
+    if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None:
+        logging.info("Plotting RNA from RNA")
+        plot_utils.plot_scatter_with_r(
+            sc_dual_full_dataset.dataset_x.size_norm_counts.X,
+            sc_rna_full_preds,
+            one_to_one=True,
+            logscale=True,
+            density_heatmap=True,
+            title=f"{DATASET_NAME} RNA > RNA".strip(),
+            fname=os.path.join(outdir, f"{prefix}_rna_rna_log.{ext}".strip("_")),
+        )
+
+
+def do_evaluation_atac_from_rna(
+    spliced_net,
+    sc_dual_full_dataset,
+    gene_names: str,
+    atac_names: str,
+    outdir: str,
+    ext: str,
+    marker_genes: List[str],
+    prefix: str = "",
+):
+    ### RNA > ATAC
+    logging.info("Inferring ATAC from RNA")
+    sc_rna_atac_full_preds = spliced_net.translate_1_to_2(sc_dual_full_dataset)
+    sc_rna_atac_full_preds_anndata = sc.AnnData(
+        scipy.sparse.csr_matrix(sc_rna_atac_full_preds),
+        obs=sc_dual_full_dataset.dataset_x.data_raw.obs,
+    )
+    sc_rna_atac_full_preds_anndata.var_names = atac_names
+    logging.info("Writing ATAC from RNA")
+    sc_rna_atac_full_preds_anndata.write(
+        os.path.join(outdir, f"{prefix}_rna_atac_adata.h5ad".strip("_"))
+    )
+
+    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None:
+        logging.info("Plotting ATAC from RNA")
+        plot_utils.plot_auroc(
+            utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
+            utils.ensure_arr(sc_rna_atac_full_preds).flatten(),
+            title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(),
+            fname=os.path.join(outdir, f"{prefix}_rna_atac_auroc.{ext}".strip("_")),
+        )
+        # plot_utils.plot_auprc(
+        #     utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
+        #     utils.ensure_arr(sc_rna_atac_full_preds),
+        #     title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(),
+        #     fname=os.path.join(outdir, f"{prefix}_rna_atac_auprc.{ext}".strip("_")),
+        # )
+
+
+def do_evaluation_atac_from_atac(
+    spliced_net,
+    sc_dual_full_dataset,
+    gene_names: str,
+    atac_names: str,
+    outdir: str,
+    ext: str,
+    marker_genes: List[str],
+    prefix: str = "",
+):
+    ### ATAC > ATAC
+    logging.info("Inferring ATAC from ATAC")
+    sc_atac_full_preds = spliced_net.translate_2_to_2(sc_dual_full_dataset)
+    sc_atac_full_preds_anndata = sc.AnnData(
+        sc_atac_full_preds,
+        obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
+    )
+    sc_atac_full_preds_anndata.var_names = atac_names
+    logging.info("Writing ATAC from ATAC")
+
+    # Infer marker bins
+    # logging.info("Getting marker bins for ATAC from ATAC")
+    # plot_utils.preprocess_anndata(sc_atac_full_preds_anndata)
+    # adata_utils.find_marker_genes(sc_atac_full_preds_anndata)
+    # inferred_marker_bins = adata_utils.flatten_marker_genes(
+    #     sc_atac_full_preds_anndata.uns["rank_genes_leiden"]
+    # )
+    # logging.info(f"Found {len(inferred_marker_bins)} marker bins for ATAC from ATAC")
+    # with open(
+    #     os.path.join(outdir, f"{prefix}_atac_atac_marker_bins.txt".strip("_")), "w"
+    # ) as sink:
+    #     sink.write("\n".join(inferred_marker_bins) + "\n")
+
+    sc_atac_full_preds_anndata.write(
+        os.path.join(outdir, f"{prefix}_atac_atac_adata.h5ad".strip("_"))
+    )
+    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None:
+        logging.info("Plotting ATAC from ATAC")
+        plot_utils.plot_auroc(
+            utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
+            utils.ensure_arr(sc_atac_full_preds).flatten(),
+            title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
+            fname=os.path.join(outdir, f"{prefix}_atac_atac_auroc.{ext}".strip("_")),
+        )
+        # plot_utils.plot_auprc(
+        #     utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
+        #     utils.ensure_arr(sc_atac_full_preds).flatten(),
+        #     title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
+        #     fname=os.path.join(outdir, f"{prefix}_atac_atac_auprc.{ext}".strip("_")),
+        # )
+
+    # Remove some objects to free memory
+    del sc_atac_full_preds
+    del sc_atac_full_preds_anndata
+
+
+def do_evaluation_rna_from_atac(
+    spliced_net,
+    sc_dual_full_dataset,
+    gene_names: str,
+    atac_names: str,
+    outdir: str,
+    ext: str,
+    marker_genes: List[str],
+    prefix: str = "",
+):
+    ### ATAC > RNA
+    logging.info("Inferring RNA from ATAC")
+    sc_atac_rna_full_preds = spliced_net.translate_2_to_1(sc_dual_full_dataset)
+    # Seurat expects everything to be sparse
+    # https://github.com/satijalab/seurat/issues/2228
+    sc_atac_rna_full_preds_anndata = sc.AnnData(
+        sc_atac_rna_full_preds,
+        obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
+    )
+    sc_atac_rna_full_preds_anndata.var_names = gene_names
+    logging.info("Writing RNA from ATAC")
+
+    # Seurat also expects the raw attribute to be populated
+    sc_atac_rna_full_preds_anndata.raw = sc_atac_rna_full_preds_anndata.copy()
+    sc_atac_rna_full_preds_anndata.write(
+        os.path.join(outdir, f"{prefix}_atac_rna_adata.h5ad".strip("_"))
+    )
+    # sc_atac_rna_full_preds_anndata.write_csvs(
+    #     os.path.join(outdir, f"{prefix}_atac_rna_constituent_csv".strip("_")),
+    #     skip_data=False,
+    # )
+    # sc_atac_rna_full_preds_anndata.to_df().to_csv(
+    #     os.path.join(outdir, f"{prefix}_atac_rna_table.csv".strip("_"))
+    # )
+
+    # If there eixsts a ground truth RNA, do RNA plotting
+    if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None:
+        logging.info("Plotting RNA from ATAC")
+        plot_utils.plot_scatter_with_r(
+            sc_dual_full_dataset.dataset_x.size_norm_counts.X,
+            sc_atac_rna_full_preds,
+            one_to_one=True,
+            logscale=True,
+            density_heatmap=True,
+            title=f"{DATASET_NAME} ATAC > RNA".strip(),
+            fname=os.path.join(outdir, f"{prefix}_atac_rna_log.{ext}".strip("_")),
+        )
+
+    # Remove objects to free memory
+    del sc_atac_rna_full_preds
+    del sc_atac_rna_full_preds_anndata
+
+
+def do_latent_evaluation(
+    spliced_net, sc_dual_full_dataset, outdir: str, prefix: str = ""
+):
+    """
+    Pull out latent space and write to file
+    """
+    logging.info("Inferring latent representations")
+    encoded_from_rna, encoded_from_atac = spliced_net.get_encoded_layer(
+        sc_dual_full_dataset
+    )
+
+    if hasattr(sc_dual_full_dataset.dataset_x, "data_raw"):
+        encoded_from_rna_adata = sc.AnnData(
+            encoded_from_rna,
+            obs=sc_dual_full_dataset.dataset_x.data_raw.obs.copy(deep=True),
+        )
+        encoded_from_rna_adata.write(
+            os.path.join(outdir, f"{prefix}_rna_encoded_adata.h5ad".strip("_"))
+        )
+    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw"):
+        encoded_from_atac_adata = sc.AnnData(
+            encoded_from_atac,
+            obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
+        )
+        encoded_from_atac_adata.write(
+            os.path.join(outdir, f"{prefix}_atac_encoded_adata.h5ad".strip("_"))
+        )
+
+
+def infer_reader(fname: str, mode: str = "atac") -> Callable:
+    """Given a filename, infer the correct reader to use"""
+    assert mode in ["atac", "rna"], f"Unrecognized mode: {mode}"
+    if fname.endswith(".h5"):
+        if mode == "atac":
+            return functools.partial(utils.sc_read_10x_h5_ft_type, ft_type="Peaks")
+        else:
+            return utils.sc_read_10x_h5_ft_type
+    elif fname.endswith(".h5ad"):
+        return ad.read_h5ad
+    else:
+        raise ValueError(f"Unrecognized extension: {fname}")
+
+
+def build_parser():
+    parser = argparse.ArgumentParser(
+        usage=__doc__,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        nargs="*",
+        required=False,
+        default=[
+            os.path.join(model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only")
+        ],
+        help="Checkpoint directory to load model from. If not given, automatically download and use a human pretrained model",
+    )
+    parser.add_argument("--prefix", type=str, default="net_", help="Checkpoint prefix")
+    parser.add_argument("--data", required=True, nargs="*", help="Data files")
+    parser.add_argument(
+        "--dataname", default="", help="Name of dataset to include in plot titles"
+    )
+    parser.add_argument(
+        "--outdir", type=str, required=True, help="Output directory for files and plots"
+    )
+    parser.add_argument(
+        "--genes",
+        type=str,
+        default="",
+        help="Genes that the model uses (inferred based on checkpoint dir if not given)",
+    )
+    parser.add_argument(
+        "--bins",
+        type=str,
+        default="",
+        help="ATAC bins that the model uses (inferred based on checkpoint dir if not given)",
+    )
+    parser.add_argument(
+        "--liftHg19toHg38",
+        action="store_true",
+        help="Liftover input ATAC bins from hg19 to hg38",
+    )
+    parser.add_argument("--device", type=str, default="0", help="Device to use")
+    parser.add_argument(
+        "--ext",
+        type=str,
+        default="pdf",
+        choices=["pdf", "png", "jpg"],
+        help="File format to use for plotting",
+    )
+    parser.add_argument(
+        "--noplot", action="store_true", help="Disable plotting, writing output only"
+    )
+    parser.add_argument(
+        "--transonly",
+        action="store_true",
+        help="Disable doing same-modality inference",
+    )
+    parser.add_argument(
+        "--skiprnasource", action="store_true", help="Skip analysis starting from RNA"
+    )
+    parser.add_argument(
+        "--skipatacsource", action="store_true", help="Skip analysis starting from ATAC"
+    )
+    parser.add_argument(
+        "--nofilter",
+        action="store_true",
+        help="Whether or not to perform filtering (note that we always discard cells with no expressed genes)",
+    )
+    return parser
+
+
+def load_rna_files_for_eval(
+    data, checkpoint: str, rna_genes_list_fname: str = "", no_filter: bool = False
+):
+    """ """
+    if not rna_genes_list_fname:
+        rna_genes_list_fname = os.path.join(checkpoint, "rna_genes.txt")
+    assert os.path.isfile(
+        rna_genes_list_fname
+    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
+    rna_genes = utils.read_delimited_file(rna_genes_list_fname)
+    rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
+    if no_filter:
+        rna_data_kwargs = {
+            k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_")
+        }
+        # Always discard cells with no expressed genes
+        rna_data_kwargs["filt_cell_min_genes"] = 1
+    rna_data_kwargs["fname"] = data
+    reader_func = functools.partial(
+        utils.sc_read_multi_files,
+        reader=lambda x: sc_data_loaders.repool_genes(
+            utils.get_ad_reader(x, ft_type="Gene Expression")(x), rna_genes
+        ),
+    )
+    rna_data_kwargs["reader"] = reader_func
+    try:
+        logging.info(f"Building RNA dataset with parameters: {rna_data_kwargs}")
+        sc_rna_full_dataset = sc_data_loaders.SingleCellDataset(
+            mode="skip",
+            **rna_data_kwargs,
+        )
+        assert all(
+            [x == y for x, y in zip(rna_genes, sc_rna_full_dataset.data_raw.var_names)]
+        ), "Mismatched genes"
+        _temp = sc_rna_full_dataset[0]  # Try that query works
+        # adata_utils.find_marker_genes(sc_rna_full_dataset.data_raw, n_genes=25)
+        # marker_genes = adata_utils.flatten_marker_genes(
+        #     sc_rna_full_dataset.data_raw.uns["rank_genes_leiden"]
+        # )
+        marker_genes = []
+        # Write out the truth
+    except (AssertionError, IndexError) as e:
+        logging.warning(f"Error when reading RNA gene expression data from {data}: {e}")
+        logging.warning("Ignoring RNA data")
+        # Update length later
+        sc_rna_full_dataset = sc_data_loaders.DummyDataset(
+            shape=len(rna_genes), length=-1
+        )
+        marker_genes = []
+    return sc_rna_full_dataset, rna_genes, marker_genes
+
+
+def load_atac_files_for_eval(
+    data: List[str],
+    checkpoint: str,
+    atac_bins_list_fname: str = "",
+    lift_hg19_to_hg39: bool = False,
+    predefined_split=None,
+):
+    """Load the ATAC files for evaluation"""
+    if not atac_bins_list_fname:
+        atac_bins_list_fname = os.path.join(checkpoint, "atac_bins.txt")
+        logging.info(f"Auto-set atac bins fname to {atac_bins_list_fname}")
+    assert os.path.isfile(
+        atac_bins_list_fname
+    ), f"Cannot find ATAC bins file: {atac_bins_list_fname}"
+    atac_bins = utils.read_delimited_file(
+        atac_bins_list_fname
+    )  # These are the bins we are using (i.e. the bins the model was trained on)
+    atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
+    atac_data_kwargs["fname"] = data
+    atac_data_kwargs["cluster_res"] = 0  # Disable clustering
+    filt_atac_keys = [k for k in atac_data_kwargs.keys() if k.startswith("filt")]
+    for k in filt_atac_keys:  # Reset filtering
+        atac_data_kwargs[k] = None
+    atac_data_kwargs["pool_genomic_interval"] = atac_bins
+    if not lift_hg19_to_hg39:
+        atac_data_kwargs["reader"] = functools.partial(
+            utils.sc_read_multi_files,
+            reader=lambda x: sc_data_loaders.repool_atac_bins(
+                infer_reader(data[0], mode="atac")(x),
+                atac_bins,
+            ),
+        )
+    else:  # Requires liftover
+        # Read, liftover, then repool
+        atac_data_kwargs["reader"] = functools.partial(
+            utils.sc_read_multi_files,
+            reader=lambda x: sc_data_loaders.repool_atac_bins(
+                sc_data_loaders.liftover_atac_adata(
+                    # utils.sc_read_10x_h5_ft_type(x, "Peaks")
+                    infer_reader(data[0], mode="atac")(x)
+                ),
+                atac_bins,
+            ),
+        )
+
+    try:
+        sc_atac_full_dataset = sc_data_loaders.SingleCellDataset(
+            mode="skip",
+            predefined_split=predefined_split if predefined_split else None,
+            **atac_data_kwargs,
+        )
+        _temp = sc_atac_full_dataset[0]  # Try that query works
+        assert all(
+            [x == y for x, y in zip(atac_bins, sc_atac_full_dataset.data_raw.var_names)]
+        )
+    except AssertionError as err:
+        logging.warning(f"Error when reading ATAC data from {data}: {err}")
+        logging.warning("Ignoring ATAC data, returning dummy dataset instead")
+        sc_atac_full_dataset = sc_data_loaders.DummyDataset(
+            shape=len(atac_bins), length=-1
+        )
+    return sc_atac_full_dataset, atac_bins
+
+
+def main():
+    parser = build_parser()
+    args = parser.parse_args()
+    logging.info(f"Evaluating: {' '.join(args.data)}")
+
+    global DATASET_NAME
+    DATASET_NAME = args.dataname
+
+    # Create output directory
+    if not os.path.isdir(args.outdir):
+        os.makedirs(args.outdir)
+
+    # Set up logging
+    logger = logging.getLogger()
+    fh = logging.FileHandler(os.path.join(args.outdir, "logging.log"), "w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+
+    if args.checkpoint[0] == os.path.join(
+        model_utils.MODEL_CACHE_DIR, "cv_logsplit_01_model_only"
+    ):
+        _ = model_utils.load_model()  # Downloads if not downloaded
+    (sc_rna_full_dataset, rna_genes, marker_genes,) = load_rna_files_for_eval(
+        args.data, args.checkpoint[0], args.genes, no_filter=args.nofilter
+    )
+
+    if hasattr(sc_rna_full_dataset, "size_norm_counts"):
+        logging.info("Writing truth RNA size normalized counts")
+        sc_rna_full_dataset.size_norm_counts.write_h5ad(
+            os.path.join(args.outdir, "truth_rna.h5ad")
+        )
+
+    sc_atac_full_dataset, atac_bins = load_atac_files_for_eval(
+        args.data,
+        args.checkpoint[0],
+        args.bins,
+        args.liftHg19toHg38,
+        sc_rna_full_dataset if hasattr(sc_rna_full_dataset, "data_raw") else None,
+    )
+    # Write out the truth
+    if hasattr(sc_atac_full_dataset, "data_raw"):
+        logging.info("Writing truth ATAC binary counts")
+        sc_atac_full_dataset.data_raw.write_h5ad(
+            os.path.join(args.outdir, "truth_atac.h5ad")
+        )
+
+    if isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset) and isinstance(
+        sc_atac_full_dataset, sc_data_loaders.DummyDataset
+    ):
+        raise ValueError("Cannot proceed with two dummy datasets for both RNA and ATAC")
+    # Update the RNA counts if we do not actually have RNA data
+    if isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset) and not isinstance(
+        sc_atac_full_dataset, sc_data_loaders.DummyDataset
+    ):
+        sc_rna_full_dataset.length = len(sc_atac_full_dataset)
+    elif isinstance(
+        sc_atac_full_dataset, sc_data_loaders.DummyDataset
+    ) and not isinstance(sc_rna_full_dataset, sc_data_loaders.DummyDataset):
+        sc_atac_full_dataset.length = len(sc_rna_full_dataset)
+
+    # Build the dual combined dataset
+    sc_dual_full_dataset = sc_data_loaders.PairedDataset(
+        sc_rna_full_dataset,
+        sc_atac_full_dataset,
+        flat_mode=True,
+    )
+
+    # Write some basic outputs related to variable and obs names
+    with open(os.path.join(args.outdir, "rna_genes.txt"), "w") as sink:
+        sink.write("\n".join(rna_genes) + "\n")
+    with open(os.path.join(args.outdir, "atac_bins.txt"), "w") as sink:
+        sink.write("\n".join(atac_bins) + "\n")
+    with open(os.path.join(args.outdir, "obs_names.txt"), "w") as sink:
+        sink.write("\n".join(sc_dual_full_dataset.obs_names))
+
+    for i, ckpt in enumerate(args.checkpoint):
+        # Dynamically determine the model we are looking at based on name
+        checkpoint_basename = os.path.basename(ckpt)
+        if checkpoint_basename.startswith("naive"):
+            logging.info(f"Inferred model to be naive")
+            model_class = autoencoders.NaiveSplicedAutoEncoder
+        else:
+            logging.info(f"Inferred model to be normal (non-naive)")
+            model_class = autoencoders.AssymSplicedAutoEncoder
+
+        prefix = "" if len(args.checkpoint) == 1 else f"model_{checkpoint_basename}"
+        spliced_net = model_utils.load_model(
+            ckpt,
+            prefix=args.prefix,
+            device=args.device,
+        )
+
+        do_latent_evaluation(
+            spliced_net=spliced_net,
+            sc_dual_full_dataset=sc_dual_full_dataset,
+            outdir=args.outdir,
+            prefix=prefix,
+        )
+
+        if (
+            isinstance(sc_rna_full_dataset, sc_data_loaders.SingleCellDataset)
+            and not args.skiprnasource
+        ):
+            if not args.transonly:
+                do_evaluation_rna_from_rna(
+                    spliced_net,
+                    sc_dual_full_dataset,
+                    rna_genes,
+                    atac_bins,
+                    args.outdir,
+                    None if args.noplot else args.ext,
+                    marker_genes,
+                    prefix=prefix,
+                )
+            do_evaluation_atac_from_rna(
+                spliced_net,
+                sc_dual_full_dataset,
+                rna_genes,
+                atac_bins,
+                args.outdir,
+                None if args.noplot else args.ext,
+                marker_genes,
+                prefix=prefix,
+            )
+        if (
+            isinstance(sc_atac_full_dataset, sc_data_loaders.SingleCellDataset)
+            and not args.skipatacsource
+        ):
+            do_evaluation_rna_from_atac(
+                spliced_net,
+                sc_dual_full_dataset,
+                rna_genes,
+                atac_bins,
+                args.outdir,
+                None if args.noplot else args.ext,
+                marker_genes,
+                prefix=prefix,
+            )
+            if not args.transonly:
+                do_evaluation_atac_from_atac(
+                    spliced_net,
+                    sc_dual_full_dataset,
+                    rna_genes,
+                    atac_bins,
+                    args.outdir,
+                    None if args.noplot else args.ext,
+                    marker_genes,
+                    prefix=prefix,
+                )
+        del spliced_net
+
+
+if __name__ == "__main__":
+    main()