Diff of /bin/predict_protein.py [000000] .. [d01132]

Switch to side-by-side view

--- a
+++ b/bin/predict_protein.py
@@ -0,0 +1,137 @@
+"""
+Script for predicting protein expression
+"""
+
+import os
+import sys
+import logging
+import argparse
+
+import numpy as np
+import pandas as pd
+
+import torch
+import torch.nn as nn
+import skorch
+
+# Locate the project's "babel" source directory (a sibling of the directory
+# holding this script) and put it, plus its "models" subdirectory, on sys.path
+# so the project-local imports that follow can be resolved.
+SRC_DIR = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
+)
+# Stop at import time if the expected directory layout is missing.
+# NOTE(review): asserts are stripped under `python -O`; in that mode a missing
+# directory would only surface as an import failure further down.
+assert os.path.isdir(SRC_DIR)
+sys.path.append(SRC_DIR)
+MODELS_DIR = os.path.join(SRC_DIR, "models")
+assert os.path.isdir(MODELS_DIR)
+sys.path.append(MODELS_DIR)
+import sc_data_loaders
+import autoencoders
+import loss_functions
+import model_utils
+import protein_utils
+import utils
+
+from predict_model import (
+    load_atac_files_for_eval,
+    load_rna_files_for_eval,
+)
+
+
+def build_parser():
+    """Build the command-line parser for protein expression prediction.
+
+    Returns:
+        An ``argparse.ArgumentParser`` that requires ``--babel``,
+        ``--protmodel``, ``--output`` and exactly one of ``--atac`` / ``--rna``
+        (each taking one or more file paths).
+    """
+    parser = argparse.ArgumentParser(
+        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--babel", type=str, required=True, help="Path to babel model")
+    parser.add_argument(
+        "--protmodel", type=str, required=True, help="Path to latent-to-protein model"
+    )
+    # Exactly one input modality must be chosen.
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    # nargs="+" (not "*"): a bare "--rna" with no files would otherwise parse
+    # to an empty list, which is falsy and therefore indistinguishable from
+    # "--rna was not given" in a truthiness check on the parsed value.
+    input_group.add_argument("--atac", type=str, nargs="+", help="Input ATAC")
+    input_group.add_argument("--rna", type=str, nargs="+", help="Input RNA")
+    parser.add_argument(
+        "--liftHg19toHg38",
+        action="store_true",
+        help="Liftover input ATAC bins from hg19 to hg38 (only used for ATAC input)",
+    )
+    parser.add_argument(
+        "-o", "--output", required=True, type=str, help="csv file to output"
+    )
+    # This script only runs prediction, so the help text no longer says "training".
+    parser.add_argument(
+        "--device", default=1, type=int, help="Device to load the model on"
+    )
+    return parser
+
+
+def main():
+    """Predict protein expression from RNA or ATAC input and write it as CSV.
+
+    Parses the command line, loads the babel model and the latent-to-protein
+    accessory model, wraps the input files in an ``EncodedDataset`` built with
+    the babel model, runs the accessory model's ``predict`` on the encoded
+    matrix, and writes a table indexed by cell barcode with one column per
+    protein to ``--output``. Parameters are logged to ``<output>.log``.
+
+    Exits through ``parser.error`` (status 2) if ``--output`` does not end in
+    ".csv" or if the selected input option lists no files.
+    """
+    parser = build_parser()
+    args = parser.parse_args()
+    # Explicit check instead of `assert`, which is stripped under `python -O`.
+    if not args.output.endswith(".csv"):
+        parser.error("--output must end with .csv")
+
+    # Specify output log file
+    logger = logging.getLogger()
+    fh = logging.FileHandler(args.output + ".log")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    # A handler's level only filters records the logger lets through. The root
+    # logger defaults to WARNING, which would drop the INFO lines below before
+    # they reach the file. Only lower the threshold; never raise one that is
+    # already more verbose.
+    if logger.getEffectiveLevel() > logging.INFO:
+        logger.setLevel(logging.INFO)
+
+    # Log parameters
+    for arg in vars(args):
+        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
+
+    # Load the model
+    babel = model_utils.load_model(args.babel, device=args.device)
+    # Load in some related files
+    rna_genes = utils.read_delimited_file(os.path.join(args.babel, "rna_genes.txt"))
+    atac_bins = utils.read_delimited_file(os.path.join(args.babel, "atac_bins.txt"))
+
+    # Load in the protein accessory model
+    babel_prot_acc_model = protein_utils.load_protein_accessory_model(args.protmodel)
+    proteins = utils.read_delimited_file(
+        os.path.join(args.protmodel, "protein_proteins.txt")
+    )
+
+    # Get the encoded layer based on input
+    if args.rna:
+        (
+            sc_rna_dset,
+            _rna_genes,
+            _marker_genes,
+            _housekeeper_genes,
+        ) = load_rna_files_for_eval(args.rna, checkpoint=args.babel, no_filter=True)
+        # The paired dataset needs both modalities; the missing ATAC side is
+        # filled with a placeholder sized to the model's ATAC bins.
+        sc_atac_dummy_dset = sc_data_loaders.DummyDataset(
+            shape=len(atac_bins), length=len(sc_rna_dset)
+        )
+        sc_dual_dataset = sc_data_loaders.PairedDataset(
+            sc_rna_dset,
+            sc_atac_dummy_dset,
+            flat_mode=True,
+        )
+        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
+            sc_dual_dataset, model=babel, input_mode="RNA"
+        )
+        cell_barcodes = list(sc_rna_dset.data_raw.obs_names)
+    else:
+        # Reached when --atac was chosen, or when the selected option was
+        # given with an empty file list; reject the latter instead of passing
+        # None / [] on to the loader.
+        if not args.atac:
+            parser.error("no input files were given to --rna / --atac")
+        # NOTE(review): the keyword is spelled "hg39" while the CLI flag says
+        # hg38 -- left unchanged; confirm it matches the signature of
+        # load_atac_files_for_eval in predict_model.
+        sc_atac_dset, _loaded_atac_bins = load_atac_files_for_eval(
+            args.atac, checkpoint=args.babel, lift_hg19_to_hg39=args.liftHg19toHg38
+        )
+        # Mirror of the RNA branch: placeholder RNA side sized to the model's genes.
+        sc_rna_dummy_dset = sc_data_loaders.DummyDataset(
+            shape=len(rna_genes), length=len(sc_atac_dset)
+        )
+        sc_dual_dataset = sc_data_loaders.PairedDataset(
+            sc_rna_dummy_dset, sc_atac_dset, flat_mode=True
+        )
+        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
+            sc_dual_dataset, model=babel, input_mode="ATAC"
+        )
+        cell_barcodes = list(sc_atac_dset.data_raw.obs_names)
+    encoded = sc_dual_encoded_dataset.encoded
+
+    # Array of preds
+    prot_preds = babel_prot_acc_model.predict(encoded.X)
+    prot_preds_df = pd.DataFrame(
+        prot_preds,
+        index=cell_barcodes,
+        columns=proteins,
+    )
+    prot_preds_df.to_csv(args.output)
+
+
+# Run the prediction pipeline only when executed as a script, not on import.
+if __name__ == "__main__":
+    main()