|
a |
|
b/bin/predict_protein.py |
|
|
1 |
""" |
|
|
2 |
Script for predicting protein expression |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
import sys |
|
|
7 |
import logging |
|
|
8 |
import argparse |
|
|
9 |
|
|
|
10 |
import numpy as np |
|
|
11 |
import pandas as pd |
|
|
12 |
|
|
|
13 |
import torch |
|
|
14 |
import torch.nn as nn |
|
|
15 |
import skorch |
|
|
16 |
|
|
|
17 |
SRC_DIR = os.path.join( |
|
|
18 |
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel" |
|
|
19 |
) |
|
|
20 |
assert os.path.isdir(SRC_DIR) |
|
|
21 |
sys.path.append(SRC_DIR) |
|
|
22 |
MODELS_DIR = os.path.join(SRC_DIR, "models") |
|
|
23 |
assert os.path.isdir(MODELS_DIR) |
|
|
24 |
sys.path.append(MODELS_DIR) |
|
|
25 |
import sc_data_loaders |
|
|
26 |
import autoencoders |
|
|
27 |
import loss_functions |
|
|
28 |
import model_utils |
|
|
29 |
import protein_utils |
|
|
30 |
import utils |
|
|
31 |
|
|
|
32 |
from predict_model import ( |
|
|
33 |
load_atac_files_for_eval, |
|
|
34 |
load_rna_files_for_eval, |
|
|
35 |
) |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def build_parser(): |
|
|
39 |
"""Build commandline parser""" |
|
|
40 |
parser = argparse.ArgumentParser( |
|
|
41 |
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter |
|
|
42 |
) |
|
|
43 |
parser.add_argument("--babel", type=str, required=True, help="Path to babel model") |
|
|
44 |
parser.add_argument( |
|
|
45 |
"--protmodel", type=str, required=True, help="Path to latent-to-protein model" |
|
|
46 |
) |
|
|
47 |
input_group = parser.add_mutually_exclusive_group(required=True) |
|
|
48 |
input_group.add_argument("--atac", type=str, nargs="*", help="Input ATAC") |
|
|
49 |
input_group.add_argument("--rna", type=str, nargs="*", help="Input RNA") |
|
|
50 |
parser.add_argument( |
|
|
51 |
"--liftHg19toHg38", |
|
|
52 |
action="store_true", |
|
|
53 |
help="Liftover input ATAC bins from hg19 to hg38 (only used for ATAC input)", |
|
|
54 |
) |
|
|
55 |
parser.add_argument( |
|
|
56 |
"-o", "--output", required=True, type=str, help="csv file to output" |
|
|
57 |
) |
|
|
58 |
parser.add_argument("--device", default=1, type=int, help="Device for training") |
|
|
59 |
return parser |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def main(): |
|
|
63 |
parser = build_parser() |
|
|
64 |
args = parser.parse_args() |
|
|
65 |
assert args.output.endswith(".csv") |
|
|
66 |
|
|
|
67 |
# Specify output log file |
|
|
68 |
logger = logging.getLogger() |
|
|
69 |
fh = logging.FileHandler(args.output + ".log") |
|
|
70 |
fh.setLevel(logging.INFO) |
|
|
71 |
logger.addHandler(fh) |
|
|
72 |
|
|
|
73 |
# Log parameters |
|
|
74 |
for arg in vars(args): |
|
|
75 |
logging.info(f"Parameter {arg}: {getattr(args, arg)}") |
|
|
76 |
|
|
|
77 |
# Load the model |
|
|
78 |
babel = model_utils.load_model(args.babel, device=args.device) |
|
|
79 |
# Load in some related files |
|
|
80 |
rna_genes = utils.read_delimited_file(os.path.join(args.babel, "rna_genes.txt")) |
|
|
81 |
atac_bins = utils.read_delimited_file(os.path.join(args.babel, "atac_bins.txt")) |
|
|
82 |
|
|
|
83 |
# Load in the protein accesory model |
|
|
84 |
babel_prot_acc_model = protein_utils.load_protein_accessory_model(args.protmodel) |
|
|
85 |
proteins = utils.read_delimited_file( |
|
|
86 |
os.path.join(args.protmodel, "protein_proteins.txt") |
|
|
87 |
) |
|
|
88 |
|
|
|
89 |
# Get the encoded layer based on input |
|
|
90 |
if args.rna: |
|
|
91 |
( |
|
|
92 |
sc_rna_dset, |
|
|
93 |
_rna_genes, |
|
|
94 |
_marker_genes, |
|
|
95 |
_housekeeper_genes, |
|
|
96 |
) = load_rna_files_for_eval(args.rna, checkpoint=args.babel, no_filter=True) |
|
|
97 |
sc_atac_dummy_dset = sc_data_loaders.DummyDataset( |
|
|
98 |
shape=len(atac_bins), length=len(sc_rna_dset) |
|
|
99 |
) |
|
|
100 |
sc_dual_dataset = sc_data_loaders.PairedDataset( |
|
|
101 |
sc_rna_dset, |
|
|
102 |
sc_atac_dummy_dset, |
|
|
103 |
flat_mode=True, |
|
|
104 |
) |
|
|
105 |
sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset( |
|
|
106 |
sc_dual_dataset, model=babel, input_mode="RNA" |
|
|
107 |
) |
|
|
108 |
cell_barcodes = list(sc_rna_dset.data_raw.obs_names) |
|
|
109 |
encoded = sc_dual_encoded_dataset.encoded |
|
|
110 |
else: |
|
|
111 |
sc_atac_dset, _loaded_atac_bins = load_atac_files_for_eval( |
|
|
112 |
args.atac, checkpoint=args.babel, lift_hg19_to_hg39=args.liftHg19toHg38 |
|
|
113 |
) |
|
|
114 |
sc_rna_dummy_dset = sc_data_loaders.DummyDataset( |
|
|
115 |
shape=len(rna_genes), length=len(sc_atac_dset) |
|
|
116 |
) |
|
|
117 |
sc_dual_dataset = sc_data_loaders.PairedDataset( |
|
|
118 |
sc_rna_dummy_dset, sc_atac_dset, flat_mode=True |
|
|
119 |
) |
|
|
120 |
sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset( |
|
|
121 |
sc_dual_dataset, model=babel, input_mode="ATAC" |
|
|
122 |
) |
|
|
123 |
cell_barcodes = list(sc_atac_dset.data_raw.obs_names) |
|
|
124 |
encoded = sc_dual_encoded_dataset.encoded |
|
|
125 |
|
|
|
126 |
# Array of preds |
|
|
127 |
prot_preds = babel_prot_acc_model.predict(encoded.X) |
|
|
128 |
prot_preds_df = pd.DataFrame( |
|
|
129 |
prot_preds, |
|
|
130 |
index=cell_barcodes, |
|
|
131 |
columns=proteins, |
|
|
132 |
) |
|
|
133 |
prot_preds_df.to_csv(args.output) |
|
|
134 |
|
|
|
135 |
|
|
|
136 |
if __name__ == "__main__": |
|
|
137 |
main() |