bin/train_model.py
"""
Train a spliced autoencoder on paired scRNA-seq/scATAC-seq data and evaluate
within- and cross-modality reconstruction on a held-out test set.
"""
|
|
import os
import sys
import logging
import argparse
import copy
import functools
import itertools

import numpy as np
import pandas as pd
import scipy.spatial
import scanpy as sc

import matplotlib.pyplot as plt
from skorch.helper import predefined_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import skorch
import skorch.helper
|
|
torch.backends.cudnn.deterministic = True  # For reproducibility
torch.backends.cudnn.benchmark = False

SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
)
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)

MODELS_DIR = os.path.join(SRC_DIR, "models")
assert os.path.isdir(MODELS_DIR)
sys.path.append(MODELS_DIR)

import sc_data_loaders
import adata_utils
import model_utils
import autoencoders
import loss_functions
import layers
import activations
import plot_utils
import utils
import metrics
import interpretation

logging.basicConfig(level=logging.INFO)

OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}
|
|
def build_parser():
    """Build argument parser"""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--data", "-d", type=str, nargs="*", help="Data files to train on",
    )
    input_group.add_argument(
        "--snareseq",
        action="store_true",
        help="Data is in SNAREseq format; use custom data loading logic for the separate RNA and ATAC files",
    )
    input_group.add_argument(
        "--shareseq",
        nargs="+",
        type=str,
        choices=["lung", "skin", "brain"],
        help="Load in the given SHAREseq datasets",
    )
    parser.add_argument(
        "--nofilter",
        action="store_true",
        help="Skip filtering (only applies with the --data argument)",
    )
    parser.add_argument(
        "--linear",
        action="store_true",
        help="Do the clustering used for data splitting in linear instead of log space",
    )
    parser.add_argument(
        "--clustermethod",
        type=str,
        choices=["leiden", "louvain"],
        default="leiden",
        help="Clustering method used to determine data splits",
    )
    parser.add_argument(
        "--validcluster", type=int, default=0, help="Cluster ID to use as valid cluster"
    )
    parser.add_argument(
        "--testcluster", type=int, default=1, help="Cluster ID to use as test cluster"
    )
    parser.add_argument(
        "--outdir", "-o", required=True, type=str, help="Directory to output to"
    )
    parser.add_argument(
        "--naive",
        "-n",
        action="store_true",
        help="Use a naive model instead of the lego model",
    )
    parser.add_argument(
        "--hidden", type=int, nargs="*", default=[16], help="Hidden dimensions"
    )
    parser.add_argument(
        "--pretrain",
        type=str,
        default="",
        help="params.pt file used to warm-initialize the model (instead of starting from scratch)",
    )
    parser.add_argument(
        "--lossweight",
        type=float,
        nargs="*",
        default=[1.33],
        help="Relative loss weight",
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="adam",
        choices=OPTIMIZER_DICT.keys(),
        help="Optimizer to use",
    )
    parser.add_argument(
        "--lr", "-l", type=float, default=[0.01], nargs="*", help="Learning rate"
    )
    parser.add_argument(
        "--batchsize", "-b", type=int, nargs="*", default=[512], help="Batch size"
    )
    parser.add_argument(
        "--earlystop",
        type=int,
        default=25,
        help="Early stopping patience (epochs without validation improvement)",
    )
    parser.add_argument(
        "--seed", type=int, nargs="*", default=[182822], help="Random seed to use"
    )
    parser.add_argument("--device", default=0, type=int, help="Device to train on")
    parser.add_argument(
        "--ext",
        type=str,
        choices=["png", "pdf", "jpg"],
        default="pdf",
        help="Output format for plots",
    )
    return parser
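
# Example invocation (illustrative; see build_parser above for the full flag list):
#   python bin/train_model.py --snareseq --outdir ./snareseq_run --device 0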
|
|
def plot_loss_history(history, fname: str):
    """Plot the train/valid loss curves from a skorch training history"""
    fig, ax = plt.subplots(dpi=300)
    # skorch's History object supports column-style indexing by key
    ax.plot(
        np.arange(len(history)), history[:, "train_loss"], label="Train",
    )
    ax.plot(
        np.arange(len(history)), history[:, "valid_loss"], label="Valid",
    )
    ax.legend()
    ax.set(
        xlabel="Epoch", ylabel="Loss",
    )
    fig.savefig(fname)
    return fig
|
|
def main():
    """Run the script"""
    parser = build_parser()
    args = parser.parse_args()
    args.outdir = os.path.abspath(args.outdir)

    if not os.path.isdir(os.path.dirname(args.outdir)):
        os.makedirs(os.path.dirname(args.outdir))

    # Specify output log file
    logger = logging.getLogger()
    fh = logging.FileHandler(f"{args.outdir}_training.log", "w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters and pytorch version
    if torch.cuda.is_available():
        logging.info(f"PyTorch CUDA version: {torch.version.cuda}")
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
|
|
    # Build RNA dataset kwargs, borrowing defaults from the SNAREseq/10x presets
    logging.info("Reading RNA data")
    if args.snareseq:
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq RNA data for: {args.shareseq}")
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
        rna_data_kwargs["fname"] = None
        rna_data_kwargs["reader"] = None
        rna_data_kwargs["cell_info"] = None
        rna_data_kwargs["gene_info"] = None
        rna_data_kwargs["transpose"] = False
        # Load in the datasets
        shareseq_rna_adatas = []
        for tissuetype in args.shareseq:
            shareseq_rna_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="RNA",
                )
            )
        shareseq_rna_adata = shareseq_rna_adatas[0]
        if len(shareseq_rna_adatas) > 1:
            shareseq_rna_adata = shareseq_rna_adata.concatenate(
                *shareseq_rna_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        rna_data_kwargs["raw_adata"] = shareseq_rna_adata
    else:
        rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
        rna_data_kwargs["fname"] = args.data
        if args.nofilter:
            rna_data_kwargs = {
                k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_")
            }
    rna_data_kwargs["data_split_by_cluster_log"] = not args.linear
    rna_data_kwargs["data_split_by_cluster"] = args.clustermethod

    sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )

    sc_rna_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="train",
    )
    sc_rna_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="valid",
    )
    sc_rna_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="test",
    )
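    # The train/valid/test split above is assigned by cluster membership on the
    # RNA data (see --validcluster/--testcluster); the ATAC dataset below reuses
    # the exact same split via its predefined_split argument.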
|
|
    # ATAC
    logging.info("Aggregating ATAC clusters")
    if args.snareseq:
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq ATAC data for {args.shareseq}")
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
        atac_data_kwargs["reader"] = None
        atac_data_kwargs["fname"] = None
        atac_data_kwargs["cell_info"] = None
        atac_data_kwargs["gene_info"] = None
        atac_data_kwargs["transpose"] = False
        atac_adatas = []
        for tissuetype in args.shareseq:
            atac_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="ATAC",
                )
            )
        atac_bins = [a.var_names for a in atac_adatas]
        if len(atac_adatas) > 1:
            atac_bins_harmonized = sc_data_loaders.harmonize_atac_intervals(*atac_bins)
            atac_adatas = [
                sc_data_loaders.repool_atac_bins(a, atac_bins_harmonized)
                for a in atac_adatas
            ]
        shareseq_atac_adata = atac_adatas[0]
        if len(atac_adatas) > 1:
            shareseq_atac_adata = shareseq_atac_adata.concatenate(
                *atac_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        atac_data_kwargs["raw_adata"] = shareseq_atac_adata
    else:
        atac_parsed = [
            utils.sc_read_10x_h5_ft_type(fname, "Peaks") for fname in args.data
        ]
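        # atac_parsed now holds one AnnData of peak counts ("Peaks" feature type)
        # per input file; their intervals are harmonized into a shared bin set below.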
|
|
        if len(atac_parsed) > 1:
            atac_bins = sc_data_loaders.harmonize_atac_intervals(
                atac_parsed[0].var_names, atac_parsed[1].var_names
            )
            for bins in atac_parsed[2:]:
                atac_bins = sc_data_loaders.harmonize_atac_intervals(
                    atac_bins, bins.var_names
                )
            logging.info(f"Aggregated {len(atac_bins)} bins")
        else:
            atac_bins = list(atac_parsed[0].var_names)

        atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
        atac_data_kwargs["fname"] = rna_data_kwargs["fname"]
        atac_data_kwargs["pool_genomic_interval"] = 0  # Do not pool
        atac_data_kwargs["reader"] = functools.partial(
            utils.sc_read_multi_files,
            reader=lambda x: sc_data_loaders.repool_atac_bins(
                utils.sc_read_10x_h5_ft_type(x, "Peaks"), atac_bins,
            ),
        )
        atac_data_kwargs["cluster_res"] = 0  # Do not bother clustering ATAC data

    sc_atac_dataset = sc_data_loaders.SingleCellDataset(
        predefined_split=sc_rna_dataset, **atac_data_kwargs
    )
    sc_atac_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="train",
    )
    sc_atac_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="valid",
    )
    sc_atac_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="test",
    )

    sc_dual_train_dataset = sc_data_loaders.PairedDataset(
        sc_rna_train_dataset, sc_atac_train_dataset, flat_mode=True,
    )
    sc_dual_valid_dataset = sc_data_loaders.PairedDataset(
        sc_rna_valid_dataset, sc_atac_valid_dataset, flat_mode=True,
    )
    sc_dual_test_dataset = sc_data_loaders.PairedDataset(
        sc_rna_test_dataset, sc_atac_test_dataset, flat_mode=True,
    )
    sc_dual_full_dataset = sc_data_loaders.PairedDataset(
        sc_rna_dataset, sc_atac_dataset, flat_mode=True,
    )
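    # Each PairedDataset pairs the RNA (modality 1) and ATAC (modality 2) views
    # of the same cells so the autoencoder is trained on both modalities jointly.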
|
|
    # Model
    param_combos = list(
        itertools.product(
            args.hidden, args.lossweight, args.lr, args.batchsize, args.seed
        )
    )
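    # One run per hyperparameter combination; with multiple combinations, each
    # run writes to its own suffixed output directory (see outdir_name below).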
|
|
    for h_dim, lw, lr, bs, rand_seed in param_combos:
        outdir_name = (
            f"{args.outdir}_hidden_{h_dim}_lossweight_{lw}_lr_{lr}_batchsize_{bs}_seed_{rand_seed}"
            if len(param_combos) > 1
            else args.outdir
        )
        if not os.path.isdir(outdir_name):
            assert not os.path.exists(outdir_name)
            os.makedirs(outdir_name)
        assert os.path.isdir(outdir_name)
        with open(os.path.join(outdir_name, "rna_genes.txt"), "w") as sink:
            for gene in sc_rna_dataset.data_raw.var_names:
                sink.write(gene + "\n")
        with open(os.path.join(outdir_name, "atac_bins.txt"), "w") as sink:
            for atac_bin in sc_atac_dataset.data_raw.var_names:
                sink.write(atac_bin + "\n")
|
|
        # Write the processed datasets for later evaluation
        ### Full
        sc_rna_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna.h5ad")
        )
        sc_rna_dataset.size_norm_log_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna_log.h5ad")
        )
        sc_atac_dataset.data_raw.write_h5ad(os.path.join(outdir_name, "full_atac.h5ad"))
        ### Train
        sc_rna_train_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "train_rna.h5ad")
        )
        sc_atac_train_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "train_atac.h5ad")
        )
        ### Valid
        sc_rna_valid_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "valid_rna.h5ad")
        )
        sc_atac_valid_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "valid_atac.h5ad")
        )
        ### Test
        sc_rna_test_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "truth_rna.h5ad")
        )
        sc_atac_test_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "truth_atac.h5ad")
        )
|
|
        # Instantiate and train model
        model_class = (
            autoencoders.NaiveSplicedAutoEncoder
            if args.naive
            else autoencoders.AssymSplicedAutoEncoder
        )
        spliced_net = autoencoders.SplicedAutoEncoderSkorchNet(
            module=model_class,
            module__hidden_dim=h_dim,  # Based on hyperparam tuning
            module__input_dim1=sc_rna_dataset.data_raw.shape[1],
            module__input_dim2=sc_atac_dataset.get_per_chrom_feature_count(),
            module__final_activations1=[
                activations.Exp(),
                activations.ClippedSoftplus(),
            ],
            module__final_activations2=nn.Sigmoid(),
            module__flat_mode=True,
            module__seed=rand_seed,
            lr=lr,  # Based on hyperparam tuning
            criterion=loss_functions.QuadLoss,
            criterion__loss2=loss_functions.BCELoss,  # BCE for the sigmoid-activated ATAC output
            criterion__loss2_weight=lw,  # numerically balance the two losses with different magnitudes
            criterion__record_history=True,
            optimizer=OPTIMIZER_DICT[args.optim],
            iterator_train__shuffle=True,
            device=utils.get_device(args.device),
            batch_size=bs,  # Based on hyperparam tuning
            max_epochs=500,
            callbacks=[
                skorch.callbacks.EarlyStopping(patience=args.earlystop),
                skorch.callbacks.LRScheduler(
                    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                    **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
                ),
                skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
                skorch.callbacks.Checkpoint(
                    dirname=outdir_name, fn_prefix="net_", monitor="valid_loss_best",
                ),
            ],
            train_split=skorch.helper.predefined_split(sc_dual_valid_dataset),
            iterator_train__num_workers=8,
            iterator_valid__num_workers=8,
        )
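        # Note: with fn_prefix="net_", the Checkpoint callback writes net_params.pt
        # (plus optimizer/history files) to outdir_name whenever the monitored
        # validation loss improves.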
|
|
        if args.pretrain:
            # Load in the warm start parameters
            spliced_net.load_params(f_params=args.pretrain)
            spliced_net.partial_fit(sc_dual_train_dataset, y=None)
        else:
            spliced_net.fit(sc_dual_train_dataset, y=None)

        fig = plot_loss_history(
            spliced_net.history, os.path.join(outdir_name, f"loss.{args.ext}")
        )
        plt.close(fig)
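        # Modality convention for the skorch net: 1 = RNA, 2 = ATAC, so e.g.
        # translate_1_to_2 maps RNA profiles to predicted ATAC profiles.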
|
|
        logging.info("Evaluating on test set")
        logging.info("Evaluating RNA > RNA")
        sc_rna_test_preds = spliced_net.translate_1_to_1(sc_dual_test_dataset)
        sc_rna_test_preds_anndata = sc.AnnData(
            sc_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_rna_test_preds.h5ad")
        )
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="RNA > RNA (test set)",
            fname=os.path.join(outdir_name, f"rna_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > ATAC")
        sc_atac_test_preds = spliced_net.translate_2_to_2(sc_dual_test_dataset)
        sc_atac_test_preds_anndata = sc.AnnData(
            sc_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_atac_test_preds.h5ad")
        )
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_atac_test_preds,
            title_prefix="ATAC > ATAC",
            fname=os.path.join(outdir_name, f"atac_atac_auroc.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > RNA")
        sc_atac_rna_test_preds = spliced_net.translate_2_to_1(sc_dual_test_dataset)
        sc_atac_rna_test_preds_anndata = sc.AnnData(
            sc_atac_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_atac_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_rna_test_preds.h5ad")
        )
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_atac_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="ATAC > RNA (test set)",
            fname=os.path.join(outdir_name, f"atac_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating RNA > ATAC")
        sc_rna_atac_test_preds = spliced_net.translate_1_to_2(sc_dual_test_dataset)
        sc_rna_atac_test_preds_anndata = sc.AnnData(
            sc_rna_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_rna_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_atac_test_preds.h5ad")
        )
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_rna_atac_test_preds,
            title_prefix="RNA > ATAC",
            fname=os.path.join(outdir_name, f"rna_atac_auroc.{args.ext}"),
        )
        plt.close(fig)
|
|
        # Release the trained network before the next hyperparameter combination
        del spliced_net


if __name__ == "__main__":
    main()