Diff of /bin/predict_protein.py [000000] .. [d01132]

Switch to unified view

a b/bin/predict_protein.py
1
"""
2
Script for predicting protein expression
3
"""
4
5
import os
6
import sys
7
import logging
8
import argparse
9
10
import numpy as np
11
import pandas as pd
12
13
import torch
14
import torch.nn as nn
15
import skorch
16
17
SRC_DIR = os.path.join(
18
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
19
)
20
assert os.path.isdir(SRC_DIR)
21
sys.path.append(SRC_DIR)
22
MODELS_DIR = os.path.join(SRC_DIR, "models")
23
assert os.path.isdir(MODELS_DIR)
24
sys.path.append(MODELS_DIR)
25
import sc_data_loaders
26
import autoencoders
27
import loss_functions
28
import model_utils
29
import protein_utils
30
import utils
31
32
from predict_model import (
33
    load_atac_files_for_eval,
34
    load_rna_files_for_eval,
35
)
36
37
38
def build_parser():
39
    """Build commandline parser"""
40
    parser = argparse.ArgumentParser(
41
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
42
    )
43
    parser.add_argument("--babel", type=str, required=True, help="Path to babel model")
44
    parser.add_argument(
45
        "--protmodel", type=str, required=True, help="Path to latent-to-protein model"
46
    )
47
    input_group = parser.add_mutually_exclusive_group(required=True)
48
    input_group.add_argument("--atac", type=str, nargs="*", help="Input ATAC")
49
    input_group.add_argument("--rna", type=str, nargs="*", help="Input RNA")
50
    parser.add_argument(
51
        "--liftHg19toHg38",
52
        action="store_true",
53
        help="Liftover input ATAC bins from hg19 to hg38 (only used for ATAC input)",
54
    )
55
    parser.add_argument(
56
        "-o", "--output", required=True, type=str, help="csv file to output"
57
    )
58
    parser.add_argument("--device", default=1, type=int, help="Device for training")
59
    return parser
60
61
62
def main():
63
    parser = build_parser()
64
    args = parser.parse_args()
65
    assert args.output.endswith(".csv")
66
67
    # Specify output log file
68
    logger = logging.getLogger()
69
    fh = logging.FileHandler(args.output + ".log")
70
    fh.setLevel(logging.INFO)
71
    logger.addHandler(fh)
72
73
    # Log parameters
74
    for arg in vars(args):
75
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
76
77
    # Load the model
78
    babel = model_utils.load_model(args.babel, device=args.device)
79
    # Load in some related files
80
    rna_genes = utils.read_delimited_file(os.path.join(args.babel, "rna_genes.txt"))
81
    atac_bins = utils.read_delimited_file(os.path.join(args.babel, "atac_bins.txt"))
82
83
    # Load in the protein accesory model
84
    babel_prot_acc_model = protein_utils.load_protein_accessory_model(args.protmodel)
85
    proteins = utils.read_delimited_file(
86
        os.path.join(args.protmodel, "protein_proteins.txt")
87
    )
88
89
    # Get the encoded layer based on input
90
    if args.rna:
91
        (
92
            sc_rna_dset,
93
            _rna_genes,
94
            _marker_genes,
95
            _housekeeper_genes,
96
        ) = load_rna_files_for_eval(args.rna, checkpoint=args.babel, no_filter=True)
97
        sc_atac_dummy_dset = sc_data_loaders.DummyDataset(
98
            shape=len(atac_bins), length=len(sc_rna_dset)
99
        )
100
        sc_dual_dataset = sc_data_loaders.PairedDataset(
101
            sc_rna_dset,
102
            sc_atac_dummy_dset,
103
            flat_mode=True,
104
        )
105
        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
106
            sc_dual_dataset, model=babel, input_mode="RNA"
107
        )
108
        cell_barcodes = list(sc_rna_dset.data_raw.obs_names)
109
        encoded = sc_dual_encoded_dataset.encoded
110
    else:
111
        sc_atac_dset, _loaded_atac_bins = load_atac_files_for_eval(
112
            args.atac, checkpoint=args.babel, lift_hg19_to_hg39=args.liftHg19toHg38
113
        )
114
        sc_rna_dummy_dset = sc_data_loaders.DummyDataset(
115
            shape=len(rna_genes), length=len(sc_atac_dset)
116
        )
117
        sc_dual_dataset = sc_data_loaders.PairedDataset(
118
            sc_rna_dummy_dset, sc_atac_dset, flat_mode=True
119
        )
120
        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
121
            sc_dual_dataset, model=babel, input_mode="ATAC"
122
        )
123
        cell_barcodes = list(sc_atac_dset.data_raw.obs_names)
124
        encoded = sc_dual_encoded_dataset.encoded
125
126
    # Array of preds
127
    prot_preds = babel_prot_acc_model.predict(encoded.X)
128
    prot_preds_df = pd.DataFrame(
129
        prot_preds,
130
        index=cell_barcodes,
131
        columns=proteins,
132
    )
133
    prot_preds_df.to_csv(args.output)
134
135
136
if __name__ == "__main__":
137
    main()