Diff of /bin/train_model.py [000000] .. [d01132]

Switch to unified view

a b/bin/train_model.py
1
"""
2
Code to train a model
3
"""
4
5
import os
6
import sys
7
import logging
8
import argparse
9
import copy
10
import functools
11
import itertools
12
13
import numpy as np
14
import pandas as pd
15
import scipy.spatial
16
import scanpy as sc
17
18
import matplotlib.pyplot as plt
19
from skorch.helper import predefined_split
20
21
import torch
22
import torch.nn as nn
23
import torch.nn.functional as F
24
import skorch
25
import skorch.helper
26
27
torch.backends.cudnn.deterministic = True  # For reproducibility
28
torch.backends.cudnn.benchmark = False
29
30
SRC_DIR = os.path.join(
31
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
32
)
33
assert os.path.isdir(SRC_DIR)
34
sys.path.append(SRC_DIR)
35
36
MODELS_DIR = os.path.join(SRC_DIR, "models")
37
assert os.path.isdir(MODELS_DIR)
38
sys.path.append(MODELS_DIR)
39
40
import sc_data_loaders
41
import adata_utils
42
import model_utils
43
import autoencoders
44
import loss_functions
45
import layers
46
import activations
47
import plot_utils
48
import utils
49
import metrics
50
import interpretation
51
52
logging.basicConfig(level=logging.INFO)
53
54
# Mapping of CLI optimizer names to their torch implementations
OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}


def build_parser():
    """Build the CLI argument parser.

    Exactly one of --data / --snareseq / --shareseq must be given to select the
    input source. Hyperparameter flags (--hidden, --lossweight, --lr,
    --batchsize, --seed) accept multiple values; training runs over the full
    cross product of the supplied values.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--data", "-d", type=str, nargs="*", help="Data files to train on",
    )
    input_group.add_argument(
        "--snareseq",
        action="store_true",
        # Fixed typo: "ATC" -> "ATAC"
        help="Data in SNAREseq format, use custom data loading logic for separated RNA ATAC files",
    )
    input_group.add_argument(
        "--shareseq",
        nargs="+",
        type=str,
        choices=["lung", "skin", "brain"],
        help="Load in the given SHAREseq datasets",
    )
    parser.add_argument(
        "--nofilter",
        action="store_true",
        help="Whether or not to perform filtering (only applies with --data argument)",
    )
    parser.add_argument(
        "--linear",
        action="store_true",
        help="Do clustering data splitting in linear instead of log space",
    )
    parser.add_argument(
        "--clustermethod",
        type=str,
        choices=["leiden", "louvain"],
        default="leiden",
        help="Clustering method to determine data splits",
    )
    parser.add_argument(
        "--validcluster", type=int, default=0, help="Cluster ID to use as valid cluster"
    )
    parser.add_argument(
        "--testcluster", type=int, default=1, help="Cluster ID to use as test cluster"
    )
    parser.add_argument(
        "--outdir", "-o", required=True, type=str, help="Directory to output to"
    )
    parser.add_argument(
        "--naive",
        "-n",
        action="store_true",
        help="Use a naive model instead of lego model",
    )
    parser.add_argument(
        "--hidden", type=int, nargs="*", default=[16], help="Hidden dimensions"
    )
    parser.add_argument(
        "--pretrain",
        type=str,
        default="",
        help="params.pt file to use to warm initialize the model (instead of starting from scratch)",
    )
    parser.add_argument(
        "--lossweight",
        type=float,
        nargs="*",
        default=[1.33],
        help="Relative loss weight",
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="adam",
        choices=OPTIMIZER_DICT.keys(),
        help="Optimizer to use",
    )
    parser.add_argument(
        "--lr", "-l", type=float, default=[0.01], nargs="*", help="Learning rate"
    )
    parser.add_argument(
        "--batchsize", "-b", type=int, nargs="*", default=[512], help="Batch size"
    )
    parser.add_argument(
        "--earlystop", type=int, default=25, help="Early stopping after N epochs"
    )
    parser.add_argument(
        "--seed", type=int, nargs="*", default=[182822], help="Random seed to use"
    )
    parser.add_argument("--device", default=0, type=int, help="Device to train on")
    parser.add_argument(
        "--ext",
        type=str,
        choices=["png", "pdf", "jpg"],
        default="pdf",
        help="Output format for plots",
    )
    return parser
157
158
159
def plot_loss_history(history, fname: str):
    """Plot train/valid loss per epoch, save the figure to ``fname``, and return it."""
    fig, axes = plt.subplots(dpi=300)
    epochs = np.arange(len(history))
    # history supports column indexing by key, e.g. history[:, "train_loss"]
    for column, label in (("train_loss", "Train"), ("valid_loss", "Valid")):
        axes.plot(epochs, history[:, column], label=label)
    axes.legend()
    axes.set(xlabel="Epoch", ylabel="Loss")
    fig.savefig(fname)
    return fig
174
175
176
def main():
    """Run the script.

    Loads RNA and ATAC data (from 10x h5 files, SNAREseq, or SHAREseq),
    splits it into train/valid/test by cluster, then trains a spliced
    autoencoder for every combination of the supplied hyperparameters,
    writing datasets, checkpoints, and evaluation plots per combination.
    """
    parser = build_parser()
    args = parser.parse_args()
    args.outdir = os.path.abspath(args.outdir)

    # args.outdir is also used as a filename prefix (e.g. the log file below),
    # so only its parent directory needs to exist here
    if not os.path.isdir(os.path.dirname(args.outdir)):
        os.makedirs(os.path.dirname(args.outdir))

    # Specify output log file
    logger = logging.getLogger()
    fh = logging.FileHandler(f"{args.outdir}_training.log", "w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters and pytorch version
    if torch.cuda.is_available():
        logging.info(f"PyTorch CUDA version: {torch.version.cuda}")
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")

    # Borrow parameters
    logging.info("Reading RNA data")
    if args.snareseq:
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq RNA data for: {args.shareseq}")
        # Start from the SNAREseq kwargs but null out the file-based loading
        # fields since we supply a pre-built AnnData below
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
        rna_data_kwargs["fname"] = None
        rna_data_kwargs["reader"] = None
        rna_data_kwargs["cell_info"] = None
        rna_data_kwargs["gene_info"] = None
        rna_data_kwargs["transpose"] = False
        # Load in the datasets
        shareseq_rna_adatas = []
        for tissuetype in args.shareseq:
            shareseq_rna_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="RNA",
                )
            )
        shareseq_rna_adata = shareseq_rna_adatas[0]
        if len(shareseq_rna_adatas) > 1:
            # Concatenate tissues on shared genes, tagging each with its tissue label
            shareseq_rna_adata = shareseq_rna_adata.concatenate(
                *shareseq_rna_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        rna_data_kwargs["raw_adata"] = shareseq_rna_adata
    else:
        rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
        rna_data_kwargs["fname"] = args.data
        if args.nofilter:
            # Drop every filtering-related kwarg
            rna_data_kwargs = {
                k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_")
            }
    rna_data_kwargs["data_split_by_cluster_log"] = not args.linear
    rna_data_kwargs["data_split_by_cluster"] = args.clustermethod

    sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )

    sc_rna_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="train",
    )
    sc_rna_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="valid",
    )
    sc_rna_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset, split="test",
    )

    # ATAC
    logging.info("Aggregating ATAC clusters")
    if args.snareseq:
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq ATAC data for {args.shareseq}")
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
        atac_data_kwargs["reader"] = None
        atac_data_kwargs["fname"] = None
        atac_data_kwargs["cell_info"] = None
        atac_data_kwargs["gene_info"] = None
        atac_data_kwargs["transpose"] = False
        atac_adatas = []
        for tissuetype in args.shareseq:
            atac_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="ATAC",
                )
            )
        atac_bins = [a.var_names for a in atac_adatas]
        if len(atac_adatas) > 1:
            # Harmonize peak intervals across tissues, then re-pool counts
            # into the shared bins so all AnnDatas share one feature space
            atac_bins_harmonized = sc_data_loaders.harmonize_atac_intervals(*atac_bins)
            atac_adatas = [
                sc_data_loaders.repool_atac_bins(a, atac_bins_harmonized)
                for a in atac_adatas
            ]
        shareseq_atac_adata = atac_adatas[0]
        if len(atac_adatas) > 1:
            shareseq_atac_adata = shareseq_atac_adata.concatenate(
                *atac_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        atac_data_kwargs["raw_adata"] = shareseq_atac_adata
    else:
        atac_parsed = [
            utils.sc_read_10x_h5_ft_type(fname, "Peaks") for fname in args.data
        ]
        if len(atac_parsed) > 1:
            # Pairwise-fold all files' intervals into one harmonized bin set
            atac_bins = sc_data_loaders.harmonize_atac_intervals(
                atac_parsed[0].var_names, atac_parsed[1].var_names
            )
            for bins in atac_parsed[2:]:
                atac_bins = sc_data_loaders.harmonize_atac_intervals(
                    atac_bins, bins.var_names
                )
            logging.info(f"Aggregated {len(atac_bins)} bins")
        else:
            atac_bins = list(atac_parsed[0].var_names)

        atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
        atac_data_kwargs["fname"] = rna_data_kwargs["fname"]
        atac_data_kwargs["pool_genomic_interval"] = 0  # Do not pool
        atac_data_kwargs["reader"] = functools.partial(
            utils.sc_read_multi_files,
            reader=lambda x: sc_data_loaders.repool_atac_bins(
                utils.sc_read_10x_h5_ft_type(x, "Peaks"), atac_bins,
            ),
        )
    atac_data_kwargs["cluster_res"] = 0  # Do not bother clustering ATAC data

    # Reuse the RNA dataset's cluster-based split so pairs stay aligned
    sc_atac_dataset = sc_data_loaders.SingleCellDataset(
        predefined_split=sc_rna_dataset, **atac_data_kwargs
    )
    sc_atac_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="train",
    )
    sc_atac_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="valid",
    )
    sc_atac_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset, split="test",
    )

    # Paired RNA+ATAC datasets for the dual-modality model
    sc_dual_train_dataset = sc_data_loaders.PairedDataset(
        sc_rna_train_dataset, sc_atac_train_dataset, flat_mode=True,
    )
    sc_dual_valid_dataset = sc_data_loaders.PairedDataset(
        sc_rna_valid_dataset, sc_atac_valid_dataset, flat_mode=True,
    )
    sc_dual_test_dataset = sc_data_loaders.PairedDataset(
        sc_rna_test_dataset, sc_atac_test_dataset, flat_mode=True,
    )
    sc_dual_full_dataset = sc_data_loaders.PairedDataset(
        sc_rna_dataset, sc_atac_dataset, flat_mode=True,
    )

    # Model: train once per hyperparameter combination
    param_combos = list(
        itertools.product(
            args.hidden, args.lossweight, args.lr, args.batchsize, args.seed
        )
    )
    for h_dim, lw, lr, bs, rand_seed in param_combos:
        # With a single combo, write directly into args.outdir; otherwise
        # suffix the directory name with the hyperparameters
        outdir_name = (
            f"{args.outdir}_hidden_{h_dim}_lossweight_{lw}_lr_{lr}_batchsize_{bs}_seed_{rand_seed}"
            if len(param_combos) > 1
            else args.outdir
        )
        if not os.path.isdir(outdir_name):
            assert not os.path.exists(outdir_name)
            os.makedirs(outdir_name)
        assert os.path.isdir(outdir_name)
        # Record the exact feature spaces used for this run
        with open(os.path.join(outdir_name, "rna_genes.txt"), "w") as sink:
            for gene in sc_rna_dataset.data_raw.var_names:
                sink.write(gene + "\n")
        with open(os.path.join(outdir_name, "atac_bins.txt"), "w") as sink:
            for atac_bin in sc_atac_dataset.data_raw.var_names:
                sink.write(atac_bin + "\n")

        # Write dataset
        ### Full
        sc_rna_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna.h5ad")
        )
        sc_rna_dataset.size_norm_log_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna_log.h5ad")
        )
        sc_atac_dataset.data_raw.write_h5ad(os.path.join(outdir_name, "full_atac.h5ad"))
        ### Train
        sc_rna_train_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "train_rna.h5ad")
        )
        sc_atac_train_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "train_atac.h5ad")
        )
        ### Valid
        sc_rna_valid_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "valid_rna.h5ad")
        )
        sc_atac_valid_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "valid_atac.h5ad")
        )
        ### Test
        # NOTE: removed a redundant second write of full_atac.h5ad that
        # previously ran here (already written in the Full section above)
        sc_rna_test_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "truth_rna.h5ad")
        )
        sc_atac_test_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "truth_atac.h5ad")
        )

        # Instantiate and train model
        model_class = (
            autoencoders.NaiveSplicedAutoEncoder
            if args.naive
            else autoencoders.AssymSplicedAutoEncoder
        )
        spliced_net = autoencoders.SplicedAutoEncoderSkorchNet(
            module=model_class,
            module__hidden_dim=h_dim,  # Based on hyperparam tuning
            module__input_dim1=sc_rna_dataset.data_raw.shape[1],
            module__input_dim2=sc_atac_dataset.get_per_chrom_feature_count(),
            module__final_activations1=[
                activations.Exp(),
                activations.ClippedSoftplus(),
            ],
            module__final_activations2=nn.Sigmoid(),
            module__flat_mode=True,
            module__seed=rand_seed,
            lr=lr,  # Based on hyperparam tuning
            criterion=loss_functions.QuadLoss,
            criterion__loss2=loss_functions.BCELoss,  # handle output of encoded layer
            criterion__loss2_weight=lw,  # numerically balance the two losses with different magnitudes
            criterion__record_history=True,
            optimizer=OPTIMIZER_DICT[args.optim],
            iterator_train__shuffle=True,
            device=utils.get_device(args.device),
            batch_size=bs,  # Based on hyperparam tuning
            max_epochs=500,
            callbacks=[
                skorch.callbacks.EarlyStopping(patience=args.earlystop),
                skorch.callbacks.LRScheduler(
                    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                    **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
                ),
                skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
                skorch.callbacks.Checkpoint(
                    dirname=outdir_name, fn_prefix="net_", monitor="valid_loss_best",
                ),
            ],
            train_split=skorch.helper.predefined_split(sc_dual_valid_dataset),
            iterator_train__num_workers=8,
            iterator_valid__num_workers=8,
        )
        if args.pretrain:
            # Load in the warm start parameters
            spliced_net.load_params(f_params=args.pretrain)
            spliced_net.partial_fit(sc_dual_train_dataset, y=None)
        else:
            spliced_net.fit(sc_dual_train_dataset, y=None)

        fig = plot_loss_history(
            spliced_net.history, os.path.join(outdir_name, f"loss.{args.ext}")
        )
        plt.close(fig)

        logging.info("Evaluating on test set")
        logging.info("Evaluating RNA > RNA")
        sc_rna_test_preds = spliced_net.translate_1_to_1(sc_dual_test_dataset)
        sc_rna_test_preds_anndata = sc.AnnData(
            sc_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_rna_test_preds.h5ad")
        )
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="RNA > RNA (test set)",
            fname=os.path.join(outdir_name, f"rna_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > ATAC")
        sc_atac_test_preds = spliced_net.translate_2_to_2(sc_dual_test_dataset)
        sc_atac_test_preds_anndata = sc.AnnData(
            sc_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_atac_test_preds.h5ad")
        )
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_atac_test_preds,
            title_prefix="ATAC > ATAC",
            fname=os.path.join(outdir_name, f"atac_atac_auroc.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > RNA")
        sc_atac_rna_test_preds = spliced_net.translate_2_to_1(sc_dual_test_dataset)
        sc_atac_rna_test_preds_anndata = sc.AnnData(
            sc_atac_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_atac_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_rna_test_preds.h5ad")
        )
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_atac_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="ATAC > RNA (test set)",
            fname=os.path.join(outdir_name, f"atac_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating RNA > ATAC")
        sc_rna_atac_test_preds = spliced_net.translate_1_to_2(sc_dual_test_dataset)
        sc_rna_atac_test_preds_anndata = sc.AnnData(
            sc_rna_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_rna_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_atac_test_preds.h5ad")
        )
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_rna_atac_test_preds,
            title_prefix="RNA > ATAC",
            fname=os.path.join(outdir_name, f"rna_atac_auroc.{args.ext}"),
        )
        plt.close(fig)

        # Free the model before the next hyperparameter combination
        del spliced_net
534
535
536
if __name__ == "__main__":
537
    main()