|
a |
|
b/bin/train_protein_predictor.py |
|
|
1 |
""" |
|
|
2 |
Script for training a protein predictor |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
import sys |
|
|
7 |
import logging |
|
|
8 |
import argparse |
|
|
9 |
import copy |
|
|
10 |
import functools |
|
|
11 |
import itertools |
|
|
12 |
import collections |
|
|
13 |
from typing import * |
|
|
14 |
import json |
|
|
15 |
|
|
|
16 |
import numpy as np |
|
|
17 |
import pandas as pd |
|
|
18 |
from scipy import sparse |
|
|
19 |
import scanpy as sc |
|
|
20 |
import anndata as ad |
|
|
21 |
|
|
|
22 |
import matplotlib.pyplot as plt |
|
|
23 |
|
|
|
24 |
import torch |
|
|
25 |
import torch.nn as nn |
|
|
26 |
import torch.nn.functional as F |
|
|
27 |
import skorch |
|
|
28 |
import skorch.helper |
|
|
29 |
|
|
|
30 |
torch.backends.cudnn.deterministic = True  # For reproducibility
torch.backends.cudnn.benchmark = False  # benchmark picks algos nondeterministically

# Locate the "babel" source tree relative to this script and make both the
# package root and its models/ subdirectory importable. The asserts fail fast
# if the script was moved out of its expected repo layout.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
)
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
MODELS_DIR = os.path.join(SRC_DIR, "models")
assert os.path.isdir(MODELS_DIR)
sys.path.append(MODELS_DIR)

# Project-local imports — resolvable only after the sys.path edits above
import sc_data_loaders
import autoencoders
import loss_functions
import model_utils
from protein_utils import LOSS_DICT, OPTIM_DICT, ACT_DICT
import utils

from train_model import plot_loss_history

logging.basicConfig(level=logging.INFO)
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def load_rna_files(
    rna_counts_fnames: List[str], model_dir: str, transpose: bool = True
) -> ad.AnnData:
    """Load the RNA files in, filling in unmeasured genes as necessary.

    Args:
        rna_counts_fnames: Raw RNA count files to read (outer-joined).
        model_dir: Trained model directory containing ``rna_genes.txt``.
        transpose: Whether inputs need transposing into obs x var orientation.

    Returns:
        AnnData whose var axis exactly matches the model's learned gene list,
        with genes absent from the input left as zero columns.
    """
    # Find the genes that the model understands
    rna_genes_list_fname = os.path.join(model_dir, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    learned_rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    assert isinstance(learned_rna_genes, list)
    assert utils.is_all_unique(
        learned_rna_genes
    ), "Learned genes list contains duplicates"

    temp_ad = utils.sc_read_multi_files(
        rna_counts_fnames,
        feature_type="Gene Expression",
        transpose=transpose,
        join="outer",
    )
    logging.info(f"Read input RNA files for {temp_ad.shape}")
    temp_ad.X = utils.ensure_arr(temp_ad.X)

    # Filter for mouse genes and remove human/mouse prefix
    temp_ad.var_names_make_unique()
    kept_var_names = [
        vname for vname in temp_ad.var_names if not vname.startswith("MOUSE_")
    ]
    if len(kept_var_names) != temp_ad.n_vars:
        temp_ad = temp_ad[:, kept_var_names]
    # BUGFIX: the previous str.strip("HUMAN_") removed any of the characters
    # H/U/M/A/N/_ from BOTH ends of every name (mangling genes like "HAND1");
    # remove only the literal "HUMAN_" prefix.
    temp_ad.var = pd.DataFrame(
        index=[
            v[len("HUMAN_"):] if v.startswith("HUMAN_") else v
            for v in kept_var_names
        ]
    )

    # Expand adata to span all genes
    # Initiating as a sparse matrix doesn't allow vectorized building
    intersected_genes = set(temp_ad.var_names).intersection(learned_rna_genes)
    assert intersected_genes, "No overlap between learned and input genes!"
    expanded_mat = np.zeros((temp_ad.n_obs, len(learned_rna_genes)))
    # Build the gene -> destination column map once instead of calling
    # list.index() per gene (which made the loop quadratic)
    dest_idx_by_gene = {g: i for i, g in enumerate(learned_rna_genes)}
    skip_count = 0
    for gene in intersected_genes:
        dest_idx = dest_idx_by_gene[gene]
        src_idx = temp_ad.var_names.get_loc(gene)
        if not isinstance(src_idx, int):
            # get_loc returns a slice/mask when the label is duplicated
            logging.warning(f"Got multiple source matches for {gene}, skipping")
            skip_count += 1
            continue
        v = utils.ensure_arr(temp_ad.X[:, src_idx]).flatten()
        expanded_mat[:, dest_idx] = v
    if skip_count:
        logging.warning(
            f"Skipped {skip_count}/{len(intersected_genes)} genes due to multiple matches"
        )
    expanded_mat = sparse.csr_matrix(expanded_mat)  # Compress
    retval = ad.AnnData(
        expanded_mat, obs=temp_ad.obs, var=pd.DataFrame(index=learned_rna_genes)
    )
    return retval
|
|
111 |
|
|
|
112 |
|
|
|
113 |
def build_parser():
    """Build CLI parser"""
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required inputs
    p.add_argument("--rnaCounts", type=str, nargs="*", required=True,
                   help="file containing raw RNA counts")
    p.add_argument("--proteinCounts", type=str, nargs="*", required=True,
                   help="file containing raw protein counts")
    p.add_argument("--encoder", required=True, type=str,
                   help="Model folder to find encoder")
    # Output / data-split options
    p.add_argument("--outdir", type=str, default=os.getcwd(),
                   help="Output directory for model, defaults to current dir")
    p.add_argument("--clusterres", type=float, default=1.5,
                   help="Cluster resolution for train/valid/test splits")
    p.add_argument("--validcluster", type=int, default=0,
                   help="Cluster ID to use as valid cluster")
    p.add_argument("--testcluster", type=int, default=1,
                   help="Cluster ID to use as test cluster")
    p.add_argument("--preprocessonly", action="store_true",
                   help="Preprocess data only, do not train model")
    # Model hyperparameters
    p.add_argument("--act", type=str, choices=ACT_DICT.keys(), default="prelu",
                   help="Activation function")
    p.add_argument("--loss", type=str, choices=LOSS_DICT.keys(), default="L1",
                   help="Loss")
    p.add_argument("--optim", type=str, choices=OPTIM_DICT.keys(), default="adam",
                   help="Optimizer")
    p.add_argument("--interdim", type=int, default=64, help="Intermediate dim")
    # Training hyperparameters
    p.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    p.add_argument("--bs", type=int, default=512, help="Batch size")
    p.add_argument("--epochs", type=int, default=600,
                   help="Maximum number of epochs to train")
    p.add_argument("--notrans", action="store_true",
                   help="Do not transpose (already in row obs form)")
    p.add_argument("--device", default=0, type=int, help="Device for training")
    return p
|
|
184 |
|
|
|
185 |
|
|
|
186 |
def main():
    """Train a protein predictor.

    Pipeline: load RNA counts, encode them through a pretrained BABEL network,
    pair the encoded latent vectors with raw protein counts, and train a small
    decoder (via skorch) to predict protein abundance from the latent space.
    The pretrained encoder itself is not trained further here.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Create output directory
    if not os.path.isdir(args.outdir):
        os.makedirs(args.outdir)

    # Specify output log file (in addition to the console handler set up at
    # module level via basicConfig)
    logger = logging.getLogger()
    fh = logging.FileHandler(os.path.join(args.outdir, "training.log"))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters and persist them for reproducibility
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
    with open(os.path.join(args.outdir, "params.json"), "w") as sink:
        json.dump(vars(args), sink, indent=4)

    # Load the pretrained model (encoder)
    pretrained_net = model_utils.load_model(args.encoder, device=args.device)

    # Load in the feature lists the encoder was trained on
    # NOTE(review): rna_genes is loaded but not referenced below — the gene
    # list is re-read inside load_rna_files instead
    rna_genes = utils.read_delimited_file(os.path.join(args.encoder, "rna_genes.txt"))
    atac_bins = utils.read_delimited_file(os.path.join(args.encoder, "atac_bins.txt"))

    # Read in the RNA, overriding the default kwargs with CLI choices and a
    # reader that aligns input genes to the encoder's learned gene list
    rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
    rna_data_kwargs["cluster_res"] = args.clusterres
    rna_data_kwargs["fname"] = args.rnaCounts
    rna_data_kwargs["reader"] = lambda x: load_rna_files(
        x, args.encoder, transpose=not args.notrans
    )

    # Construct data folds (train/valid/test assigned by cluster membership)
    full_sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )
    full_sc_rna_dataset.data_raw.write_h5ad(os.path.join(args.outdir, "full_rna.h5ad"))

    # Build one paired (encoded RNA, protein) dataset per split; the "all"
    # split is written to disk but discarded at the unpack below
    train_valid_test_dsets = []
    for mode in ["all", "train", "valid", "test"]:
        logging.info(f"Constructing {mode} dataset")
        sc_rna_dataset = sc_data_loaders.SingleCellDatasetSplit(
            full_sc_rna_dataset, split=mode
        )
        sc_rna_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_rna.h5ad")
        )  # Write RNA input
        # Placeholder ATAC of the right shape, since the pretrained network
        # expects paired RNA/ATAC input
        sc_atac_dummy_dataset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=len(sc_rna_dataset)
        )
        # RNA and fake ATAC
        sc_dual_dataset = sc_data_loaders.PairedDataset(
            sc_rna_dataset,
            sc_atac_dummy_dataset,
            flat_mode=True,
        )
        # encoded(RNA) as "x" and RNA + fake ATAC as "y"
        sc_rna_encoded_dataset = sc_data_loaders.EncodedDataset(
            sc_dual_dataset, model=pretrained_net, input_mode="RNA"
        )
        sc_rna_encoded_dataset.encoded.write_h5ad(
            os.path.join(args.outdir, f"{mode}_encoded.h5ad")
        )
        # Protein counts aligned to the same cells (obs_names) as the RNA
        sc_protein_dataset = sc_data_loaders.SingleCellProteinDataset(
            args.proteinCounts,
            obs_names=sc_rna_dataset.obs_names,
            transpose=not args.notrans,
        )
        sc_protein_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_protein.h5ad")
        )  # Write protein
        # x = 16 dimensional encoded layer, y = 25 dimensional protein array
        sc_rna_protein_dataset = sc_data_loaders.SplicedDataset(
            sc_rna_encoded_dataset, sc_protein_dataset
        )
        _temp = sc_rna_protein_dataset[0]  # ensure calling works
        train_valid_test_dsets.append(sc_rna_protein_dataset)

    # Unpack and do sanity checks: each split must yield (x, y) pairs of
    # matching shapes
    _, sc_rna_prot_train, sc_rna_prot_valid, sc_rna_prot_test = train_valid_test_dsets
    x, y, z = sc_rna_prot_train[0], sc_rna_prot_valid[0], sc_rna_prot_test[0]
    assert (
        x[0].shape == y[0].shape == z[0].shape
    ), f"Got mismatched shapes: {x[0].shape} {y[0].shape} {z[0].shape}"
    assert (
        x[1].shape == y[1].shape == z[1].shape
    ), f"Got mismatched shapes: {x[1].shape} {y[1].shape} {z[1].shape}"

    # Record the protein markers being predicted, and verify the written file
    # round-trips to the same count
    protein_markers = list(sc_protein_dataset.data_raw.var_names)
    with open(os.path.join(args.outdir, "protein_proteins.txt"), "w") as sink:
        sink.write("\n".join(protein_markers) + "\n")
    assert len(
        utils.read_delimited_file(os.path.join(args.outdir, "protein_proteins.txt"))
    ) == len(protein_markers)
    logging.info(f"Predicting on {len(protein_markers)} proteins")

    if args.preprocessonly:
        return

    # Decoder from the 16-dim latent space to per-protein outputs, trained
    # with early stopping, LR decay on plateau, gradient clipping, and
    # checkpointing of the best validation loss
    protein_decoder_skorch = skorch.NeuralNet(
        module=autoencoders.Decoder,
        module__num_units=16,
        module__intermediate_dim=args.interdim,
        module__num_outputs=len(protein_markers),
        module__activation=ACT_DICT[args.act],
        module__final_activation=nn.Identity(),
        # module__final_activation=nn.Linear(
        #     len(protein_markers), len(protein_markers), bias=True
        # ),  # Paper uses identity activation instead
        lr=args.lr,
        criterion=LOSS_DICT[args.loss],  # Other works use L1 loss
        optimizer=OPTIM_DICT[args.optim],
        batch_size=args.bs,
        max_epochs=args.epochs,
        callbacks=[
            skorch.callbacks.EarlyStopping(patience=15),
            skorch.callbacks.LRScheduler(
                policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                patience=5,
                factor=0.1,
                min_lr=1e-6,
                # **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
            ),
            skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
            skorch.callbacks.Checkpoint(
                dirname=args.outdir,
                fn_prefix="net_",
                monitor="valid_loss_best",
            ),
        ],
        train_split=skorch.helper.predefined_split(sc_rna_prot_valid),
        iterator_train__num_workers=8,
        iterator_valid__num_workers=8,
        device=utils.get_device(args.device),
    )
    # y=None because the dataset itself yields (x, y) pairs
    protein_decoder_skorch.fit(sc_rna_prot_train, y=None)

    # Plot the loss history
    fig = plot_loss_history(
        protein_decoder_skorch.history, os.path.join(args.outdir, "loss.pdf")
    )
|
|
333 |
|
|
|
334 |
|
|
|
335 |
# Script entry point
if __name__ == "__main__":
    main()